use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use tokio::sync::RwLock;

use workflow_graph_shared::{JobStatus, Workflow};

use crate::error::SchedulerError;
use crate::traits::*;
10
/// Shared, async-guarded handle to the in-memory workflow state.
pub type SharedState = Arc<RwLock<WorkflowState>>;
13
/// In-memory registry of workflows tracked by the scheduler.
pub struct WorkflowState {
    // Workflows keyed by workflow id (callers insert with the same id the
    // scheduler later looks up).
    pub workflows: HashMap<String, Workflow>,
}
18
19impl WorkflowState {
20 pub fn new() -> Self {
21 Self {
22 workflows: HashMap::new(),
23 }
24 }
25}
26
27impl Default for WorkflowState {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
/// Event-driven DAG scheduler: dispatches root jobs, listens for job
/// lifecycle events from the queue, and enqueues dependents as their
/// upstreams complete.
pub struct DagScheduler<Q: JobQueue, A: ArtifactStore> {
    // Job queue used both to enqueue work and to subscribe to job events.
    queue: Arc<Q>,
    // Store for completed-job outputs, read back as upstream inputs for
    // dependent jobs.
    artifacts: Arc<A>,
    // Shared workflow/job status state, mutated as events arrive.
    state: SharedState,
}
42
impl<Q: JobQueue, A: ArtifactStore> DagScheduler<Q, A> {
    /// Creates a scheduler over the given queue, artifact store, and shared
    /// workflow state.
    pub fn new(queue: Arc<Q>, artifacts: Arc<A>, state: SharedState) -> Self {
        Self {
            queue,
            artifacts,
            state,
        }
    }

    /// Starts (or restarts) a workflow: resets every job's per-run fields to a
    /// clean `Queued` state, then enqueues the root jobs (those with no
    /// dependencies).
    ///
    /// # Errors
    /// Returns [`SchedulerError::WorkflowNotFound`] if `workflow_id` is not
    /// registered; propagates queue errors from `enqueue`.
    pub async fn start_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        // Collect the roots while holding the write lock, then release it
        // before doing any async queue I/O.
        let root_jobs = {
            let mut state = self.state.write().await;
            let wf = state
                .workflows
                .get_mut(workflow_id)
                .ok_or_else(|| SchedulerError::WorkflowNotFound(workflow_id.to_string()))?;

            // Reset per-run fields so a restarted workflow begins clean.
            // NOTE(review): `attempt` is not reset here — confirm whether a
            // restart should also zero it.
            for job in &mut wf.jobs {
                job.status = JobStatus::Queued;
                job.duration_secs = None;
                job.started_at = None;
                job.output = None;
            }

            // Roots are exactly the jobs with no upstream dependencies.
            wf.jobs
                .iter()
                .filter(|j| j.depends_on.is_empty())
                .map(|j| (j.id.clone(), j.command.clone()))
                .collect::<Vec<_>>()
        };

        for (job_id, command) in root_jobs {
            let queued = QueuedJob {
                job_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                // Roots have no upstreams, hence no upstream outputs.
                upstream_outputs: HashMap::new(),
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0,
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Cancels a workflow: asks the queue to cancel its jobs, then marks every
    /// still-pending (`Queued`/`Running`) job as `Cancelled`. Jobs already in
    /// a terminal state keep their status.
    ///
    /// # Errors
    /// Propagates errors from the queue's `cancel_workflow`. An unknown
    /// `workflow_id` is not an error; local state is simply left untouched.
    pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        // Cancel on the queue first so no new work is handed out while we
        // update local state below.
        self.queue.cancel_workflow(workflow_id).await?;

        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id) {
            for job in &mut wf.jobs {
                if job.status == JobStatus::Queued || job.status == JobStatus::Running {
                    job.status = JobStatus::Cancelled;
                }
            }
        }
        Ok(())
    }

    /// Main scheduler loop: subscribes to the queue's event stream and routes
    /// each event through [`Self::handle_event`] until the channel closes.
    ///
    /// Intended to be spawned as a long-lived task (hence `self: Arc<Self>`).
    pub async fn run(self: Arc<Self>) {
        let mut rx = self.queue.subscribe();

        loop {
            let event = match rx.recv().await {
                Ok(event) => event,
                // The broadcast receiver fell behind and dropped `n` events;
                // log and keep running rather than crash the scheduler.
                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("Scheduler lagged by {n} events, some jobs may need manual recovery");
                    continue;
                }
                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                    eprintln!("Queue event channel closed, scheduler shutting down");
                    break;
                }
            };

            // A failure handling one event never stops the loop; it is logged
            // and the next event is processed.
            if let Err(e) = self.handle_event(event).await {
                eprintln!("Scheduler error: {e}");
            }
        }
    }

    /// Routes a single queue event to the matching `on_*` handler.
    ///
    /// # Errors
    /// Only `Completed` handling is fallible (artifact/queue I/O); all other
    /// handlers mutate local state infallibly.
    async fn handle_event(&self, event: JobEvent) -> Result<(), SchedulerError> {
        match event {
            JobEvent::Started {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_job_started(&workflow_id, &job_id).await;
            }
            JobEvent::Completed {
                workflow_id,
                job_id,
                outputs,
            } => {
                self.on_job_completed(&workflow_id, &job_id, outputs)
                    .await?;
            }
            JobEvent::Failed {
                workflow_id,
                job_id,
                error,
                retryable,
            } => {
                self.on_job_failed(&workflow_id, &job_id, &error, retryable)
                    .await;
            }
            JobEvent::LeaseExpired {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_lease_expired(&workflow_id, &job_id).await;
            }
            JobEvent::Cancelled {
                workflow_id,
                job_id,
            } => {
                self.on_job_cancelled(&workflow_id, &job_id).await;
            }
            JobEvent::Ready { .. } => {
                // Intentionally a no-op: dependents are enqueued directly by
                // on_job_completed, so Ready carries no extra work here.
            }
        }
        Ok(())
    }

    /// Marks the job `Running` and records its start timestamp (ms since
    /// epoch, stored as f64 to match the `started_at` field type).
    async fn on_job_started(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Running;
            job.started_at = Some(now_ms() as f64);
        }
    }

    /// Handles a successful completion: persists the job's outputs, marks the
    /// job `Success` (recording its duration), then enqueues every dependent
    /// whose upstreams have now all succeeded, attaching their stored outputs.
    ///
    /// # Errors
    /// Propagates artifact-store and queue errors.
    async fn on_job_completed(
        &self,
        workflow_id: &str,
        job_id: &str,
        outputs: HashMap<String, String>,
    ) -> Result<(), SchedulerError> {
        // Persist outputs before anything else; ready dependents read them
        // back from the artifact store further down.
        self.artifacts
            .put_outputs(workflow_id, job_id, outputs)
            .await?;

        // Compute the newly-ready dependents under the write lock, then drop
        // the lock before performing async queue/artifact I/O.
        let ready_jobs = {
            let mut state = self.state.write().await;
            let wf = match state.workflows.get_mut(workflow_id) {
                Some(wf) => wf,
                // Workflow no longer tracked: nothing to update or enqueue.
                None => return Ok(()),
            };

            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Success;
                // Duration derived from the Started event's timestamp; clamped
                // at 0 in case of clock skew.
                if let Some(started) = job.started_at {
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // A dependent is ready when it is still Queued, depends on the job
            // that just finished, and ALL of its upstreams have succeeded.
            let ready: Vec<(String, String, Vec<String>)> = wf
                .jobs
                .iter()
                .filter(|j| {
                    j.status == JobStatus::Queued
                        && j.depends_on.contains(&job_id.to_string())
                        && j.depends_on.iter().all(|dep| {
                            wf.jobs
                                .iter()
                                .find(|dj| dj.id == *dep)
                                .is_some_and(|dj| dj.status == JobStatus::Success)
                        })
                })
                .map(|j| (j.id.clone(), j.command.clone(), j.depends_on.clone()))
                .collect();

            ready
        };

        for (next_id, command, deps) in ready_jobs {
            // Gather every upstream's stored outputs so the worker receives
            // them alongside the command.
            let upstream_outputs = self
                .artifacts
                .get_upstream_outputs(workflow_id, &deps)
                .await?;

            let queued = QueuedJob {
                job_id: next_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                upstream_outputs,
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0,
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Handles a job failure.
    ///
    /// Retryable failures put the job back to `Queued` with its start time
    /// cleared (re-delivery is presumably driven by the queue's retry policy —
    /// confirm against the `JobQueue` implementation). Permanent failures
    /// record the error text as the job's output, compute its duration, and
    /// mark every transitive dependent `Skipped`, since an upstream can never
    /// succeed.
    async fn on_job_failed(&self, workflow_id: &str, job_id: &str, error: &str, retryable: bool) {
        let mut state = self.state.write().await;
        let Some(wf) = state.workflows.get_mut(workflow_id) else {
            return;
        };

        if retryable {
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Queued;
                job.started_at = None;
            }
        } else {
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Failure;
                job.output = Some(error.to_string());
                if let Some(started) = job.started_at {
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // Everything downstream of the failed job can never run; mark the
            // whole transitive closure as Skipped.
            let skip_ids = find_transitive_downstream(wf, job_id);
            for skip_id in &skip_ids {
                if let Some(j) = wf.jobs.iter_mut().find(|j| j.id == *skip_id) {
                    j.status = JobStatus::Skipped;
                }
            }
        }
    }

    /// Handles an expired worker lease: the job is returned to `Queued` with
    /// its start time cleared so it can be claimed again.
    async fn on_lease_expired(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Queued;
            job.started_at = None;
        }
    }

    /// Marks a single job `Cancelled` in response to a queue cancellation
    /// event (workflow-wide cancellation is handled by `cancel_workflow`).
    async fn on_job_cancelled(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Cancelled;
        }
    }
}
320
321fn find_transitive_downstream(wf: &Workflow, job_id: &str) -> Vec<String> {
323 let mut result = Vec::new();
324 let mut stack = vec![job_id.to_string()];
325
326 while let Some(current) = stack.pop() {
327 for job in &wf.jobs {
328 if job.depends_on.contains(¤t) && !result.contains(&job.id) {
329 result.push(job.id.clone());
330 stack.push(job.id.clone());
331 }
332 }
333 }
334
335 result
336}
337
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, returns 0 instead of
/// panicking.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryArtifactStore, InMemoryJobQueue};

    /// Three-job fixture: `a` is the root; `b` and `c` both depend on `a`.
    fn sample_workflow() -> Workflow {
        Workflow {
            id: "wf1".into(),
            name: "test".into(),
            trigger: "manual".into(),
            jobs: vec![
                workflow_graph_shared::Job {
                    id: "a".into(),
                    name: "Job A".into(),
                    status: JobStatus::Queued,
                    command: "echo a".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec![],
                    output: None,
                },
                workflow_graph_shared::Job {
                    id: "b".into(),
                    name: "Job B".into(),
                    status: JobStatus::Queued,
                    command: "echo b".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                },
                workflow_graph_shared::Job {
                    id: "c".into(),
                    name: "Job C".into(),
                    status: JobStatus::Queued,
                    command: "echo c".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                },
            ],
        }
    }

    /// Builds a scheduler over in-memory queue/artifact stores with
    /// `sample_workflow` registered under id "wf1"; returns the pieces a
    /// test needs to drive and inspect it.
    async fn setup() -> (
        Arc<DagScheduler<InMemoryJobQueue, InMemoryArtifactStore>>,
        Arc<InMemoryJobQueue>,
        SharedState,
    ) {
        let queue = Arc::new(InMemoryJobQueue::new());
        let artifacts = Arc::new(InMemoryArtifactStore::new());
        let state = Arc::new(RwLock::new(WorkflowState::new()));

        state
            .write()
            .await
            .workflows
            .insert("wf1".into(), sample_workflow());

        let scheduler = Arc::new(DagScheduler::new(
            queue.clone(),
            artifacts.clone(),
            state.clone(),
        ));

        (scheduler, queue, state)
    }

    /// Starting the workflow should enqueue exactly the dependency-free root
    /// ("a") and nothing else.
    #[tokio::test]
    async fn test_start_workflow_enqueues_roots() {
        let (scheduler, queue, _state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // First claim hands out the root job.
        let (job, _lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(job.job_id, "a");

        // Dependents b/c must not be claimable yet — a has not completed.
        assert!(
            queue
                .claim("w1", &[], std::time::Duration::from_secs(30))
                .await
                .unwrap()
                .is_none()
        );
    }

    /// Completing the root should enqueue both dependents (b and c) and mark
    /// the root Success in shared state.
    #[tokio::test]
    async fn test_completed_enqueues_downstream() {
        let (scheduler, queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        let (_, lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Simulate the worker lifecycle: started, then completed on the queue,
        // then the Completed event delivered to the scheduler.
        scheduler
            .handle_event(JobEvent::Started {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                worker_id: "w1".into(),
            })
            .await
            .unwrap();

        queue
            .complete(&lease.lease_id, HashMap::new())
            .await
            .unwrap();

        scheduler
            .handle_event(JobEvent::Completed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                outputs: HashMap::new(),
            })
            .await
            .unwrap();

        // Both dependents should now be claimable (order unspecified).
        let (job1, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        let (job2, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        let mut ids = vec![job1.job_id, job2.job_id];
        ids.sort();
        assert_eq!(ids, vec!["b", "c"]);

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Success
        );
    }

    /// A non-retryable failure of the root must mark it Failure and skip its
    /// entire transitive downstream (b and c).
    #[tokio::test]
    async fn test_failure_skips_downstream() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        scheduler
            .handle_event(JobEvent::Failed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                error: "boom".into(),
                retryable: false,
            })
            .await
            .unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Failure
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "b").unwrap().status,
            JobStatus::Skipped
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "c").unwrap().status,
            JobStatus::Skipped
        );
    }

    /// Cancelling a freshly-started workflow (all jobs Queued) should cancel
    /// every job.
    #[tokio::test]
    async fn test_cancel_workflow() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();
        scheduler.cancel_workflow("wf1").await.unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        for job in &wf.jobs {
            assert!(
                job.status == JobStatus::Cancelled,
                "job {} should be cancelled, got {:?}",
                job.id,
                job.status
            );
        }
    }
}