//! DAG scheduler for `workflow_graph_queue` (scheduler.rs).

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use workflow_graph_shared::{JobStatus, Workflow};
7
8use crate::error::SchedulerError;
9use crate::traits::*;
10
11/// Shared workflow state, readable by the frontend polling API.
12pub type SharedState = Arc<RwLock<WorkflowState>>;
13
14/// In-memory workflow state that the frontend reads via polling.
15pub struct WorkflowState {
16    pub workflows: HashMap<String, Workflow>,
17}
18
19impl WorkflowState {
20    pub fn new() -> Self {
21        Self {
22            workflows: HashMap::new(),
23        }
24    }
25}
26
27impl Default for WorkflowState {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
/// Event-driven DAG scheduler.
///
/// Listens for `JobEvent`s from the queue and enqueues downstream jobs
/// when their dependencies are satisfied. Replaces the inline orchestrator.
pub struct DagScheduler<Q: JobQueue, A: ArtifactStore> {
    // Queue used to enqueue ready jobs and to subscribe to lifecycle events.
    queue: Arc<Q>,
    // Store for completed-job outputs, read back as upstream inputs for
    // downstream jobs.
    artifacts: Arc<A>,
    // Shared workflow state mirrored for the frontend polling API.
    state: SharedState,
}
42
43impl<Q: JobQueue, A: ArtifactStore> DagScheduler<Q, A> {
44    pub fn new(queue: Arc<Q>, artifacts: Arc<A>, state: SharedState) -> Self {
45        Self {
46            queue,
47            artifacts,
48            state,
49        }
50    }
51
52    /// Initiate a workflow run: reset all jobs to Queued, then enqueue root jobs.
53    pub async fn start_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
54        let root_jobs = {
55            let mut state = self.state.write().await;
56            let wf = state
57                .workflows
58                .get_mut(workflow_id)
59                .ok_or_else(|| SchedulerError::WorkflowNotFound(workflow_id.to_string()))?;
60
61            // Reset all jobs
62            for job in &mut wf.jobs {
63                job.status = JobStatus::Queued;
64                job.duration_secs = None;
65                job.started_at = None;
66                job.output = None;
67            }
68
69            // Find root jobs (no dependencies)
70            wf.jobs
71                .iter()
72                .filter(|j| j.depends_on.is_empty())
73                .map(|j| (j.id.clone(), j.command.clone()))
74                .collect::<Vec<_>>()
75        };
76
77        // Enqueue root jobs
78        for (job_id, command) in root_jobs {
79            let queued = QueuedJob {
80                job_id,
81                workflow_id: workflow_id.to_string(),
82                command,
83                required_labels: vec![],
84                retry_policy: RetryPolicy::default(),
85                attempt: 0,
86                upstream_outputs: HashMap::new(),
87                enqueued_at_ms: now_ms(),
88                delayed_until_ms: 0,
89            };
90            self.queue.enqueue(queued).await?;
91        }
92
93        Ok(())
94    }
95
96    /// Cancel a running workflow: cancel all pending/active jobs.
97    pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
98        self.queue.cancel_workflow(workflow_id).await?;
99
100        let mut state = self.state.write().await;
101        if let Some(wf) = state.workflows.get_mut(workflow_id) {
102            for job in &mut wf.jobs {
103                if job.status == JobStatus::Queued || job.status == JobStatus::Running {
104                    job.status = JobStatus::Cancelled;
105                }
106            }
107        }
108        Ok(())
109    }
110
111    /// Run the scheduler event loop. Listens for queue events and drives the DAG.
112    /// This should be spawned as a background task.
113    pub async fn run(self: Arc<Self>) {
114        let mut rx = self.queue.subscribe();
115
116        loop {
117            let event = match rx.recv().await {
118                Ok(event) => event,
119                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
120                    eprintln!("Scheduler lagged by {n} events, some jobs may need manual recovery");
121                    continue;
122                }
123                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
124                    eprintln!("Queue event channel closed, scheduler shutting down");
125                    break;
126                }
127            };
128
129            if let Err(e) = self.handle_event(event).await {
130                eprintln!("Scheduler error: {e}");
131            }
132        }
133    }
134
135    async fn handle_event(&self, event: JobEvent) -> Result<(), SchedulerError> {
136        match event {
137            JobEvent::Started {
138                workflow_id,
139                job_id,
140                ..
141            } => {
142                self.on_job_started(&workflow_id, &job_id).await;
143            }
144            JobEvent::Completed {
145                workflow_id,
146                job_id,
147                outputs,
148            } => {
149                self.on_job_completed(&workflow_id, &job_id, outputs)
150                    .await?;
151            }
152            JobEvent::Failed {
153                workflow_id,
154                job_id,
155                error,
156                retryable,
157            } => {
158                self.on_job_failed(&workflow_id, &job_id, &error, retryable)
159                    .await;
160            }
161            JobEvent::LeaseExpired {
162                workflow_id,
163                job_id,
164                ..
165            } => {
166                self.on_lease_expired(&workflow_id, &job_id).await;
167            }
168            JobEvent::Cancelled {
169                workflow_id,
170                job_id,
171            } => {
172                self.on_job_cancelled(&workflow_id, &job_id).await;
173            }
174            JobEvent::Ready { .. } => {
175                // No action needed — job is in queue waiting for a worker
176            }
177        }
178        Ok(())
179    }
180
181    async fn on_job_started(&self, workflow_id: &str, job_id: &str) {
182        let mut state = self.state.write().await;
183        if let Some(wf) = state.workflows.get_mut(workflow_id)
184            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
185        {
186            job.status = JobStatus::Running;
187            job.started_at = Some(now_ms() as f64);
188        }
189    }
190
191    async fn on_job_completed(
192        &self,
193        workflow_id: &str,
194        job_id: &str,
195        outputs: HashMap<String, String>,
196    ) -> Result<(), SchedulerError> {
197        // Store outputs
198        self.artifacts
199            .put_outputs(workflow_id, job_id, outputs)
200            .await?;
201
202        // Update state
203        let ready_jobs = {
204            let mut state = self.state.write().await;
205            let wf = match state.workflows.get_mut(workflow_id) {
206                Some(wf) => wf,
207                None => return Ok(()),
208            };
209
210            // Mark this job as success
211            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
212                job.status = JobStatus::Success;
213                if let Some(started) = job.started_at {
214                    job.duration_secs =
215                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
216                }
217            }
218
219            // Find downstream jobs whose deps are ALL succeeded
220            let ready: Vec<(String, String, Vec<String>)> = wf
221                .jobs
222                .iter()
223                .filter(|j| {
224                    j.status == JobStatus::Queued
225                        && j.depends_on.contains(&job_id.to_string())
226                        && j.depends_on.iter().all(|dep| {
227                            wf.jobs
228                                .iter()
229                                .find(|dj| dj.id == *dep)
230                                .is_some_and(|dj| dj.status == JobStatus::Success)
231                        })
232                })
233                .map(|j| (j.id.clone(), j.command.clone(), j.depends_on.clone()))
234                .collect();
235
236            ready
237        };
238
239        // Enqueue ready downstream jobs with upstream outputs
240        for (next_id, command, deps) in ready_jobs {
241            let upstream_outputs = self
242                .artifacts
243                .get_upstream_outputs(workflow_id, &deps)
244                .await?;
245
246            let queued = QueuedJob {
247                job_id: next_id,
248                workflow_id: workflow_id.to_string(),
249                command,
250                required_labels: vec![],
251                retry_policy: RetryPolicy::default(),
252                attempt: 0,
253                upstream_outputs,
254                enqueued_at_ms: now_ms(),
255                delayed_until_ms: 0,
256            };
257            self.queue.enqueue(queued).await?;
258        }
259
260        Ok(())
261    }
262
263    async fn on_job_failed(&self, workflow_id: &str, job_id: &str, error: &str, retryable: bool) {
264        let mut state = self.state.write().await;
265        let Some(wf) = state.workflows.get_mut(workflow_id) else {
266            return;
267        };
268
269        if retryable {
270            // Job was re-enqueued by the queue — mark as queued again
271            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
272                job.status = JobStatus::Queued;
273                job.started_at = None;
274            }
275        } else {
276            // Permanent failure — mark job and skip all downstream
277            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
278                job.status = JobStatus::Failure;
279                job.output = Some(error.to_string());
280                if let Some(started) = job.started_at {
281                    job.duration_secs =
282                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
283                }
284            }
285
286            // Skip transitive downstream
287            let skip_ids = find_transitive_downstream(wf, job_id);
288            for skip_id in &skip_ids {
289                if let Some(j) = wf.jobs.iter_mut().find(|j| j.id == *skip_id) {
290                    j.status = JobStatus::Skipped;
291                }
292            }
293        }
294    }
295
296    async fn on_lease_expired(&self, workflow_id: &str, job_id: &str) {
297        // The queue already handled re-enqueueing if retries remain.
298        // We just need to update the state if the job was permanently failed.
299        let mut state = self.state.write().await;
300        if let Some(wf) = state.workflows.get_mut(workflow_id)
301            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
302        {
303            // If the queue re-enqueued it, mark as queued; otherwise mark as failure
304            // We can't easily know here, so mark as Queued — the next Started event
305            // will update it correctly.
306            job.status = JobStatus::Queued;
307            job.started_at = None;
308        }
309    }
310
311    async fn on_job_cancelled(&self, workflow_id: &str, job_id: &str) {
312        let mut state = self.state.write().await;
313        if let Some(wf) = state.workflows.get_mut(workflow_id)
314            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
315        {
316            job.status = JobStatus::Cancelled;
317        }
318    }
319}
320
321/// Find all jobs transitively downstream of the given job.
322fn find_transitive_downstream(wf: &Workflow, job_id: &str) -> Vec<String> {
323    let mut result = Vec::new();
324    let mut stack = vec![job_id.to_string()];
325
326    while let Some(current) = stack.pop() {
327        for job in &wf.jobs {
328            if job.depends_on.contains(&current) && !result.contains(&job.id) {
329                result.push(job.id.clone());
330                stack.push(job.id.clone());
331            }
332        }
333    }
334
335    result
336}
337
/// Current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0)
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryArtifactStore, InMemoryJobQueue};

    // Three-job DAG: "a" is the root; "b" and "c" both depend directly on "a".
    fn sample_workflow() -> Workflow {
        Workflow {
            id: "wf1".into(),
            name: "test".into(),
            trigger: "manual".into(),
            jobs: vec![
                workflow_graph_shared::Job {
                    id: "a".into(),
                    name: "Job A".into(),
                    status: JobStatus::Queued,
                    command: "echo a".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec![],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                },
                workflow_graph_shared::Job {
                    id: "b".into(),
                    name: "Job B".into(),
                    status: JobStatus::Queued,
                    command: "echo b".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                },
                workflow_graph_shared::Job {
                    id: "c".into(),
                    name: "Job C".into(),
                    status: JobStatus::Queued,
                    command: "echo c".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                },
            ],
        }
    }

    // Build a scheduler over in-memory queue/artifact implementations with
    // the sample workflow pre-registered in shared state.
    async fn setup() -> (
        Arc<DagScheduler<InMemoryJobQueue, InMemoryArtifactStore>>,
        Arc<InMemoryJobQueue>,
        SharedState,
    ) {
        let queue = Arc::new(InMemoryJobQueue::new());
        let artifacts = Arc::new(InMemoryArtifactStore::new());
        let state = Arc::new(RwLock::new(WorkflowState::new()));

        state
            .write()
            .await
            .workflows
            .insert("wf1".into(), sample_workflow());

        let scheduler = Arc::new(DagScheduler::new(
            queue.clone(),
            artifacts.clone(),
            state.clone(),
        ));

        (scheduler, queue, state)
    }

    // Starting a workflow enqueues exactly the root jobs (only "a" here).
    #[tokio::test]
    async fn test_start_workflow_enqueues_roots() {
        let (scheduler, queue, _state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Only job "a" (root) should be enqueued
        let (job, _lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(job.job_id, "a");

        // No more jobs available
        assert!(
            queue
                .claim("w1", &[], std::time::Duration::from_secs(30))
                .await
                .unwrap()
                .is_none()
        );
    }

    // Completing "a" makes both direct dependents "b" and "c" claimable and
    // flips "a" to Success in the shared state.
    #[tokio::test]
    async fn test_completed_enqueues_downstream() {
        let (scheduler, queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Claim and complete job A
        let (_, lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Process the Started event
        scheduler
            .handle_event(JobEvent::Started {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                worker_id: "w1".into(),
            })
            .await
            .unwrap();

        // Complete job A
        queue
            .complete(&lease.lease_id, HashMap::new())
            .await
            .unwrap();

        // Process the Completed event — should enqueue B and C
        scheduler
            .handle_event(JobEvent::Completed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                outputs: HashMap::new(),
            })
            .await
            .unwrap();

        // Both B and C should now be claimable
        let (job1, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        let (job2, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Claim order is unspecified, so compare sorted ids.
        let mut ids = vec![job1.job_id, job2.job_id];
        ids.sort();
        assert_eq!(ids, vec!["b", "c"]);

        // Check state
        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Success
        );
    }

    // A permanent failure of the root marks it Failure and skips every
    // transitive downstream job.
    #[tokio::test]
    async fn test_failure_skips_downstream() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Simulate job A failing permanently
        scheduler
            .handle_event(JobEvent::Failed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                error: "boom".into(),
                retryable: false,
            })
            .await
            .unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Failure
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "b").unwrap().status,
            JobStatus::Skipped
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "c").unwrap().status,
            JobStatus::Skipped
        );
    }

    // Cancelling a started workflow marks every Queued/Running job Cancelled.
    #[tokio::test]
    async fn test_cancel_workflow() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();
        scheduler.cancel_workflow("wf1").await.unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        for job in &wf.jobs {
            assert!(
                job.status == JobStatus::Cancelled,
                "job {} should be cancelled, got {:?}",
                job.id,
                job.status
            );
        }
    }
}
565}