1use std::collections::HashMap;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use workflow_graph_shared::{JobStatus, Workflow};
7
8use crate::error::SchedulerError;
9use crate::traits::*;
10
/// Handle to the mutable workflow registry, shared across async tasks
/// (scheduler, and presumably an API layer — other holders not visible here).
pub type SharedState = Arc<RwLock<WorkflowState>>;
13
/// In-memory registry of workflows. Normally accessed through [`SharedState`]
/// so multiple tasks can read/write it behind an async `RwLock`.
pub struct WorkflowState {
    // All known workflows, keyed by workflow id (the same id used by
    // `DagScheduler::start_workflow` / `cancel_workflow`).
    pub workflows: HashMap<String, Workflow>,
}
18
19impl WorkflowState {
20 pub fn new() -> Self {
21 Self {
22 workflows: HashMap::new(),
23 }
24 }
25}
26
27impl Default for WorkflowState {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33pub struct DagScheduler<Q: JobQueue, A: ArtifactStore> {
38 queue: Arc<Q>,
39 artifacts: Arc<A>,
40 state: SharedState,
41}
42
impl<Q: JobQueue, A: ArtifactStore> DagScheduler<Q, A> {
    /// Wires the scheduler to its queue, artifact store, and shared state.
    /// Construction has no side effects; nothing runs until [`Self::run`].
    pub fn new(queue: Arc<Q>, artifacts: Arc<A>, state: SharedState) -> Self {
        Self {
            queue,
            artifacts,
            state,
        }
    }

    /// (Re)starts a workflow: resets every job to `Queued` with cleared
    /// timing/output fields, then enqueues the root jobs (those whose
    /// `depends_on` is empty).
    ///
    /// # Errors
    /// Returns `SchedulerError::WorkflowNotFound` if `workflow_id` is not in
    /// the shared state; queue errors from `enqueue` are propagated.
    pub async fn start_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        // Scope the write lock so it is dropped before the enqueue awaits below
        // (avoids holding the state lock across queue I/O).
        let root_jobs = {
            let mut state = self.state.write().await;
            let wf = state
                .workflows
                .get_mut(workflow_id)
                .ok_or_else(|| SchedulerError::WorkflowNotFound(workflow_id.to_string()))?;

            // Reset ALL jobs so a restarted workflow carries no stale results.
            for job in &mut wf.jobs {
                job.status = JobStatus::Queued;
                job.duration_secs = None;
                job.started_at = None;
                job.output = None;
            }

            // Clone out (id, command) for the roots so the lock can be released.
            wf.jobs
                .iter()
                .filter(|j| j.depends_on.is_empty())
                .map(|j| (j.id.clone(), j.command.clone()))
                .collect::<Vec<_>>()
        };

        for (job_id, command) in root_jobs {
            let queued = QueuedJob {
                job_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                // Roots have no upstream jobs, hence no upstream outputs.
                upstream_outputs: HashMap::new(),
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0,
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Cancels a workflow: first tells the queue to drop/cancel its jobs, then
    /// marks every still-pending (`Queued`/`Running`) job `Cancelled` in state.
    /// Jobs already in a terminal state are left untouched.
    pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        // Queue first, so no new leases are handed out while we flip statuses.
        self.queue.cancel_workflow(workflow_id).await?;

        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id) {
            for job in &mut wf.jobs {
                if job.status == JobStatus::Queued || job.status == JobStatus::Running {
                    job.status = JobStatus::Cancelled;
                }
            }
        }
        Ok(())
    }

    /// Main event loop: subscribes to the queue's broadcast channel and
    /// dispatches each event to `handle_event` until the channel closes.
    ///
    /// On `Lagged` (broadcast buffer overflow) the missed events are lost and
    /// only logged — there is no resynchronization, so affected jobs may need
    /// manual recovery, as the message says.
    pub async fn run(self: Arc<Self>) {
        let mut rx = self.queue.subscribe();

        loop {
            let event = match rx.recv().await {
                Ok(event) => event,
                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("Scheduler lagged by {n} events, some jobs may need manual recovery");
                    continue;
                }
                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                    eprintln!("Queue event channel closed, scheduler shutting down");
                    break;
                }
            };

            // Per-event errors are logged but never stop the loop.
            if let Err(e) = self.handle_event(event).await {
                eprintln!("Scheduler error: {e}");
            }
        }
    }

    /// Routes one queue event to the matching state-transition handler.
    async fn handle_event(&self, event: JobEvent) -> Result<(), SchedulerError> {
        match event {
            JobEvent::Started {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_job_started(&workflow_id, &job_id).await;
            }
            JobEvent::Completed {
                workflow_id,
                job_id,
                outputs,
            } => {
                self.on_job_completed(&workflow_id, &job_id, outputs)
                    .await?;
            }
            JobEvent::Failed {
                workflow_id,
                job_id,
                error,
                retryable,
            } => {
                self.on_job_failed(&workflow_id, &job_id, &error, retryable)
                    .await;
            }
            JobEvent::LeaseExpired {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_lease_expired(&workflow_id, &job_id).await;
            }
            JobEvent::Cancelled {
                workflow_id,
                job_id,
            } => {
                self.on_job_cancelled(&workflow_id, &job_id).await;
            }
            JobEvent::Ready { .. } => {
                // Intentionally ignored: no scheduler-side state transition.
                // Presumably Ready is consumed by workers claiming jobs —
                // TODO confirm against the queue implementation.
            }
        }
        Ok(())
    }

    /// Marks the job `Running` and records its start timestamp (epoch millis,
    /// stored as f64 to match the `started_at` field type).
    /// Silently no-ops if the workflow or job is unknown.
    async fn on_job_started(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Running;
            job.started_at = Some(now_ms() as f64);
        }
    }

    /// Handles successful completion of `job_id`:
    /// 1. persists its outputs to the artifact store (before taking the lock),
    /// 2. marks the job `Success` and records its duration,
    /// 3. enqueues every dependent job whose dependencies are now all `Success`.
    async fn on_job_completed(
        &self,
        workflow_id: &str,
        job_id: &str,
        outputs: HashMap<String, String>,
    ) -> Result<(), SchedulerError> {
        // Store outputs first so downstream jobs can read them once enqueued.
        self.artifacts
            .put_outputs(workflow_id, job_id, outputs)
            .await?;

        // Compute the ready set under the write lock; enqueue only after the
        // lock is dropped so queue awaits don't block other state users.
        let ready_jobs = {
            let mut state = self.state.write().await;
            let wf = match state.workflows.get_mut(workflow_id) {
                Some(wf) => wf,
                // Workflow vanished from state: outputs are stored above, but
                // there is nothing left to schedule.
                None => return Ok(()),
            };

            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Success;
                if let Some(started) = job.started_at {
                    // Clamp at 0 to guard against clock skew between the two reads.
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // A job is ready when it is still pending (`Queued`), directly
            // depends on the job that just finished, and ALL of its
            // dependencies have status `Success`.
            let ready: Vec<(String, String, Vec<String>)> = wf
                .jobs
                .iter()
                .filter(|j| {
                    j.status == JobStatus::Queued
                        && j.depends_on.contains(&job_id.to_string())
                        && j.depends_on.iter().all(|dep| {
                            wf.jobs
                                .iter()
                                .find(|dj| dj.id == *dep)
                                .is_some_and(|dj| dj.status == JobStatus::Success)
                        })
                })
                .map(|j| (j.id.clone(), j.command.clone(), j.depends_on.clone()))
                .collect();

            ready
        };

        for (next_id, command, deps) in ready_jobs {
            // Hand the dependencies' stored outputs to the next job.
            let upstream_outputs = self
                .artifacts
                .get_upstream_outputs(workflow_id, &deps)
                .await?;

            let queued = QueuedJob {
                job_id: next_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                upstream_outputs,
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0,
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Handles a failure report for `job_id`.
    ///
    /// Retryable failures only reset the job to `Queued` here; re-delivery is
    /// presumably driven by the queue's retry policy (`RetryPolicy`/`attempt`)
    /// — TODO confirm, the scheduler does not re-enqueue in this path.
    /// Non-retryable failures mark the job `Failure`, store the error text in
    /// `output`, record the duration, and mark every transitive downstream job
    /// `Skipped` so the workflow cannot stall waiting on them.
    async fn on_job_failed(&self, workflow_id: &str, job_id: &str, error: &str, retryable: bool) {
        let mut state = self.state.write().await;
        let Some(wf) = state.workflows.get_mut(workflow_id) else {
            return;
        };

        if retryable {
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Queued;
                job.started_at = None;
            }
        } else {
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Failure;
                job.output = Some(error.to_string());
                if let Some(started) = job.started_at {
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // Everything that (transitively) depends on the failed job can
            // never become ready; mark it Skipped so state reflects that.
            let skip_ids = find_transitive_downstream(wf, job_id);
            for skip_id in &skip_ids {
                if let Some(j) = wf.jobs.iter_mut().find(|j| j.id == *skip_id) {
                    j.status = JobStatus::Skipped;
                }
            }
        }
    }

    /// A worker's lease on the job expired (worker died or stalled,
    /// presumably). The job goes back to `Queued` with its start time cleared;
    /// re-enqueueing is assumed to happen on the queue side — TODO confirm.
    async fn on_lease_expired(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Queued;
            job.started_at = None;
        }
    }

    /// Marks a single job `Cancelled` (no downstream propagation here;
    /// whole-workflow cancellation goes through `cancel_workflow`).
    async fn on_job_cancelled(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Cancelled;
        }
    }
}
320
321fn find_transitive_downstream(wf: &Workflow, job_id: &str) -> Vec<String> {
323 let mut result = Vec::new();
324 let mut stack = vec![job_id.to_string()];
325
326 while let Some(current) = stack.pop() {
327 for job in &wf.jobs {
328 if job.depends_on.contains(¤t) && !result.contains(&job.id) {
329 result.push(job.id.clone());
330 stack.push(job.id.clone());
331 }
332 }
333 }
334
335 result
336}
337
/// Current wall-clock time as milliseconds since the Unix epoch.
/// Falls back to 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_millis() as u64)
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryArtifactStore, InMemoryJobQueue};

    /// Builds a three-job DAG: root `a`, with `b` and `c` each depending on `a`.
    fn sample_workflow() -> Workflow {
        Workflow {
            id: "wf1".into(),
            name: "test".into(),
            trigger: "manual".into(),
            jobs: vec![
                workflow_graph_shared::Job {
                    id: "a".into(),
                    name: "Job A".into(),
                    status: JobStatus::Queued,
                    command: "echo a".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec![],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                    ports: vec![],
                    children: None,
                    collapsed: false,
                },
                workflow_graph_shared::Job {
                    id: "b".into(),
                    name: "Job B".into(),
                    status: JobStatus::Queued,
                    command: "echo b".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                    ports: vec![],
                    children: None,
                    collapsed: false,
                },
                workflow_graph_shared::Job {
                    id: "c".into(),
                    name: "Job C".into(),
                    status: JobStatus::Queued,
                    command: "echo c".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                    ports: vec![],
                    children: None,
                    collapsed: false,
                },
            ],
        }
    }

    /// Creates a scheduler over in-memory queue/artifact implementations with
    /// `sample_workflow` registered in shared state under id "wf1".
    async fn setup() -> (
        Arc<DagScheduler<InMemoryJobQueue, InMemoryArtifactStore>>,
        Arc<InMemoryJobQueue>,
        SharedState,
    ) {
        let queue = Arc::new(InMemoryJobQueue::new());
        let artifacts = Arc::new(InMemoryArtifactStore::new());
        let state = Arc::new(RwLock::new(WorkflowState::new()));

        state
            .write()
            .await
            .workflows
            .insert("wf1".into(), sample_workflow());

        let scheduler = Arc::new(DagScheduler::new(
            queue.clone(),
            artifacts.clone(),
            state.clone(),
        ));

        (scheduler, queue, state)
    }

    /// Starting the workflow must enqueue exactly the root job `a`;
    /// `b`/`c` are not claimable until `a` succeeds.
    #[tokio::test]
    async fn test_start_workflow_enqueues_roots() {
        let (scheduler, queue, _state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // "w1" here is a worker id, unrelated to the workflow id "wf1".
        let (job, _lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(job.job_id, "a");

        // Nothing else should be claimable yet.
        assert!(
            queue
                .claim("w1", &[], std::time::Duration::from_secs(30))
                .await
                .unwrap()
                .is_none()
        );
    }

    /// After `a` completes, both dependents `b` and `c` must be enqueued and
    /// `a` must be marked Success in shared state.
    #[tokio::test]
    async fn test_completed_enqueues_downstream() {
        let (scheduler, queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Claim `a` to obtain its lease.
        let (_, lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Feed the Started event directly into the scheduler (no run loop here).
        scheduler
            .handle_event(JobEvent::Started {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                worker_id: "w1".into(),
            })
            .await
            .unwrap();

        // Complete `a` on the queue side...
        queue
            .complete(&lease.lease_id, HashMap::new())
            .await
            .unwrap();

        // ...and notify the scheduler, which should enqueue `b` and `c`.
        scheduler
            .handle_event(JobEvent::Completed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                outputs: HashMap::new(),
            })
            .await
            .unwrap();

        let (job1, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        let (job2, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Claim order is not guaranteed; sort before comparing.
        let mut ids = vec![job1.job_id, job2.job_id];
        ids.sort();
        assert_eq!(ids, vec!["b", "c"]);

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Success
        );
    }

    /// A non-retryable failure of the root must mark it Failure and mark every
    /// transitive dependent (`b`, `c`) Skipped.
    #[tokio::test]
    async fn test_failure_skips_downstream() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        scheduler
            .handle_event(JobEvent::Failed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                error: "boom".into(),
                retryable: false,
            })
            .await
            .unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Failure
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "b").unwrap().status,
            JobStatus::Skipped
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "c").unwrap().status,
            JobStatus::Skipped
        );
    }

    /// Cancelling right after start: all three jobs are Queued at that point,
    /// so every one of them must end up Cancelled.
    #[tokio::test]
    async fn test_cancel_workflow() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();
        scheduler.cancel_workflow("wf1").await.unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        for job in &wf.jobs {
            assert!(
                job.status == JobStatus::Cancelled,
                "job {} should be cancelled, got {:?}",
                job.id,
                job.status
            );
        }
    }
}