1use std::collections::HashMap;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use workflow_graph_shared::{JobStatus, Workflow};
7
8use crate::error::SchedulerError;
9use crate::traits::*;
10
/// Workflow state shared across tasks (scheduler loop, tests, and — presumably —
/// API handlers elsewhere in the crate) behind an async `RwLock`.
pub type SharedState = Arc<RwLock<WorkflowState>>;
13
/// In-memory registry of workflows and their per-job runtime status.
pub struct WorkflowState {
    // Keyed by workflow id (see `start_workflow`, which looks up by id).
    pub workflows: HashMap<String, Workflow>,
}
18
19impl WorkflowState {
20 pub fn new() -> Self {
21 Self {
22 workflows: HashMap::new(),
23 }
24 }
25}
26
27impl Default for WorkflowState {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
/// Event-driven DAG scheduler: enqueues root jobs, then reacts to queue
/// events to advance dependents and propagate failures.
pub struct DagScheduler<Q: JobQueue, A: ArtifactStore> {
    // Queue that hands jobs to workers and emits lifecycle events.
    queue: Arc<Q>,
    // Store for completed-job outputs, read back to seed downstream inputs.
    artifacts: Arc<A>,
    // Shared workflow/job status, also visible outside the scheduler.
    state: SharedState,
}
42
impl<Q: JobQueue, A: ArtifactStore> DagScheduler<Q, A> {
    /// Builds a scheduler over the given queue, artifact store, and shared state.
    pub fn new(queue: Arc<Q>, artifacts: Arc<A>, state: SharedState) -> Self {
        Self {
            queue,
            artifacts,
            state,
        }
    }

    /// Resets every job in the workflow and enqueues its root jobs
    /// (those with no dependencies).
    ///
    /// # Errors
    /// Returns `SchedulerError::WorkflowNotFound` if `workflow_id` is not
    /// registered in the shared state; propagates queue enqueue errors.
    pub async fn start_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        // Collect root jobs in an inner block so the write lock is dropped
        // before awaiting on the queue below.
        let root_jobs = {
            let mut state = self.state.write().await;
            let wf = state
                .workflows
                .get_mut(workflow_id)
                .ok_or_else(|| SchedulerError::WorkflowNotFound(workflow_id.to_string()))?;

            // Reset all jobs to a clean Queued state. Non-root jobs stay
            // Queued as "waiting on deps" until on_job_completed promotes them.
            for job in &mut wf.jobs {
                job.status = JobStatus::Queued;
                job.duration_secs = None;
                job.started_at = None;
                job.output = None;
            }

            // Roots are exactly the jobs with no dependencies.
            wf.jobs
                .iter()
                .filter(|j| j.depends_on.is_empty())
                .map(|j| (j.id.clone(), j.command.clone()))
                .collect::<Vec<_>>()
        };

        for (job_id, command) in root_jobs {
            let queued = QueuedJob {
                job_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                // Roots have no upstream jobs, hence no inherited outputs.
                upstream_outputs: HashMap::new(),
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0,
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Cancels a workflow: asks the queue to drop its pending work, then
    /// marks every still-active (Queued/Running) job as Cancelled.
    pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        // Cancel at the queue first so no new work is handed out while the
        // in-memory statuses are being updated below.
        self.queue.cancel_workflow(workflow_id).await?;

        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id) {
            for job in &mut wf.jobs {
                // Jobs already in a terminal state keep their status.
                if job.status == JobStatus::Queued || job.status == JobStatus::Running {
                    job.status = JobStatus::Cancelled;
                }
            }
        }
        Ok(())
    }

    /// Main scheduler loop: subscribes to queue events and dispatches each
    /// to `handle_event` until the event channel closes.
    pub async fn run(self: Arc<Self>) {
        let mut rx = self.queue.subscribe();

        loop {
            let event = match rx.recv().await {
                Ok(event) => event,
                // Broadcast receivers drop the oldest events when they fall
                // behind; log and keep going rather than kill the loop.
                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("Scheduler lagged by {n} events, some jobs may need manual recovery");
                    continue;
                }
                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                    eprintln!("Queue event channel closed, scheduler shutting down");
                    break;
                }
            };

            // A failure handling one event must not stop the scheduler.
            if let Err(e) = self.handle_event(event).await {
                eprintln!("Scheduler error: {e}");
            }
        }
    }

    /// Routes one queue event to the matching state-transition handler.
    async fn handle_event(&self, event: JobEvent) -> Result<(), SchedulerError> {
        match event {
            JobEvent::Started {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_job_started(&workflow_id, &job_id).await;
            }
            JobEvent::Completed {
                workflow_id,
                job_id,
                outputs,
            } => {
                self.on_job_completed(&workflow_id, &job_id, outputs)
                    .await?;
            }
            JobEvent::Failed {
                workflow_id,
                job_id,
                error,
                retryable,
            } => {
                self.on_job_failed(&workflow_id, &job_id, &error, retryable)
                    .await;
            }
            JobEvent::LeaseExpired {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_lease_expired(&workflow_id, &job_id).await;
            }
            JobEvent::Cancelled {
                workflow_id,
                job_id,
            } => {
                self.on_job_cancelled(&workflow_id, &job_id).await;
            }
            JobEvent::Ready { .. } => {
                // Intentionally ignored: no scheduler state changes on Ready.
                // NOTE(review): presumably consumed by workers/observers —
                // confirm against the JobQueue event contract.
            }
        }
        Ok(())
    }

    /// Marks a job Running and records its start time (epoch millis as f64).
    async fn on_job_started(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Running;
            job.started_at = Some(now_ms() as f64);
        }
    }

    /// Persists the finished job's outputs, marks it Success, and enqueues
    /// any downstream jobs whose dependencies have now all succeeded.
    async fn on_job_completed(
        &self,
        workflow_id: &str,
        job_id: &str,
        outputs: HashMap<String, String>,
    ) -> Result<(), SchedulerError> {
        // Persist outputs before updating state so newly-enqueued dependents
        // can always fetch them from the artifact store.
        self.artifacts
            .put_outputs(workflow_id, job_id, outputs)
            .await?;

        // Compute newly-ready jobs under the write lock, but drop the lock
        // before the artifact-fetch / enqueue awaits below.
        let ready_jobs = {
            let mut state = self.state.write().await;
            let wf = match state.workflows.get_mut(workflow_id) {
                Some(wf) => wf,
                // Workflow no longer registered: nothing to update.
                None => return Ok(()),
            };

            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Success;
                if let Some(started) = job.started_at {
                    // Duration in whole seconds, clamped so clock skew can't
                    // produce a negative value.
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // A job is ready when it is still waiting (Queued), depends on
            // the job that just finished, and ALL of its deps are Success.
            let ready: Vec<(String, String, Vec<String>)> = wf
                .jobs
                .iter()
                .filter(|j| {
                    j.status == JobStatus::Queued
                        && j.depends_on.contains(&job_id.to_string())
                        && j.depends_on.iter().all(|dep| {
                            wf.jobs
                                .iter()
                                .find(|dj| dj.id == *dep)
                                .is_some_and(|dj| dj.status == JobStatus::Success)
                        })
                })
                .map(|j| (j.id.clone(), j.command.clone(), j.depends_on.clone()))
                .collect();

            ready
        };

        for (next_id, command, deps) in ready_jobs {
            // Gather every upstream output so the worker sees its inputs.
            let upstream_outputs = self
                .artifacts
                .get_upstream_outputs(workflow_id, &deps)
                .await?;

            let queued = QueuedJob {
                job_id: next_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                upstream_outputs,
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0,
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Records a job failure. Retryable failures put the job back to Queued;
    /// terminal failures mark it Failure and skip all transitive dependents.
    async fn on_job_failed(&self, workflow_id: &str, job_id: &str, error: &str, retryable: bool) {
        let mut state = self.state.write().await;
        let Some(wf) = state.workflows.get_mut(workflow_id) else {
            return;
        };

        if retryable {
            // NOTE(review): only the status is reset here — this assumes the
            // queue re-delivers the job itself on a retryable failure;
            // confirm against the JobQueue implementation.
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Queued;
                job.started_at = None;
            }
        } else {
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Failure;
                // Surface the error message as the job's output for display.
                job.output = Some(error.to_string());
                if let Some(started) = job.started_at {
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // Everything downstream can never run now; mark it Skipped so
            // the workflow does not hang waiting on unreachable jobs.
            let skip_ids = find_transitive_downstream(wf, job_id);
            for skip_id in &skip_ids {
                if let Some(j) = wf.jobs.iter_mut().find(|j| j.id == *skip_id) {
                    j.status = JobStatus::Skipped;
                }
            }
        }
    }

    /// Returns a job to Queued after its worker lease expired.
    /// NOTE(review): like the retryable-failure path, this assumes the queue
    /// re-delivers expired jobs on its own — confirm.
    async fn on_lease_expired(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Queued;
            job.started_at = None;
        }
    }

    /// Marks a single job Cancelled in response to a queue-side cancellation.
    async fn on_job_cancelled(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Cancelled;
        }
    }
}
320
321fn find_transitive_downstream(wf: &Workflow, job_id: &str) -> Vec<String> {
323 let mut result = Vec::new();
324 let mut stack = vec![job_id.to_string()];
325
326 while let Some(current) = stack.pop() {
327 for job in &wf.jobs {
328 if job.depends_on.contains(¤t) && !result.contains(&job.id) {
329 result.push(job.id.clone());
330 stack.push(job.id.clone());
331 }
332 }
333 }
334
335 result
336}
337
/// Milliseconds since the Unix epoch, saturating to 0 if the system clock
/// reads earlier than the epoch.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        // Pre-epoch clock: behave like the zero duration.
        Err(_) => 0,
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryArtifactStore, InMemoryJobQueue};

    /// Fan-out DAG fixture: root `a`; `b` and `c` both depend on `a`.
    fn sample_workflow() -> Workflow {
        Workflow {
            id: "wf1".into(),
            name: "test".into(),
            trigger: "manual".into(),
            jobs: vec![
                workflow_graph_shared::Job {
                    id: "a".into(),
                    name: "Job A".into(),
                    status: JobStatus::Queued,
                    command: "echo a".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec![],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                },
                workflow_graph_shared::Job {
                    id: "b".into(),
                    name: "Job B".into(),
                    status: JobStatus::Queued,
                    command: "echo b".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                },
                workflow_graph_shared::Job {
                    id: "c".into(),
                    name: "Job C".into(),
                    status: JobStatus::Queued,
                    command: "echo c".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                },
            ],
        }
    }

    /// Builds a scheduler over the in-memory queue/artifact implementations
    /// with the sample workflow pre-registered under id "wf1".
    async fn setup() -> (
        Arc<DagScheduler<InMemoryJobQueue, InMemoryArtifactStore>>,
        Arc<InMemoryJobQueue>,
        SharedState,
    ) {
        let queue = Arc::new(InMemoryJobQueue::new());
        let artifacts = Arc::new(InMemoryArtifactStore::new());
        let state = Arc::new(RwLock::new(WorkflowState::new()));

        state
            .write()
            .await
            .workflows
            .insert("wf1".into(), sample_workflow());

        let scheduler = Arc::new(DagScheduler::new(
            queue.clone(),
            artifacts.clone(),
            state.clone(),
        ));

        (scheduler, queue, state)
    }

    /// Starting the workflow enqueues exactly the root job `a` — dependents
    /// `b`/`c` must not be claimable yet.
    #[tokio::test]
    async fn test_start_workflow_enqueues_roots() {
        let (scheduler, queue, _state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // "w1" is an arbitrary worker id for claiming.
        let (job, _lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(job.job_id, "a");

        // Only the root should be in the queue at this point.
        assert!(
            queue
                .claim("w1", &[], std::time::Duration::from_secs(30))
                .await
                .unwrap()
                .is_none()
        );
    }

    /// Completing `a` should enqueue both dependents `b` and `c` and mark
    /// `a` as Success in shared state.
    #[tokio::test]
    async fn test_completed_enqueues_downstream() {
        let (scheduler, queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Claim the root so we hold a lease to complete against.
        let (_, lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Simulate the worker reporting the job as started.
        scheduler
            .handle_event(JobEvent::Started {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                worker_id: "w1".into(),
            })
            .await
            .unwrap();

        queue
            .complete(&lease.lease_id, HashMap::new())
            .await
            .unwrap();

        // Drive the scheduler's completion handling directly.
        scheduler
            .handle_event(JobEvent::Completed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                outputs: HashMap::new(),
            })
            .await
            .unwrap();

        let (job1, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        let (job2, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Claim order is unspecified; sort before comparing.
        let mut ids = vec![job1.job_id, job2.job_id];
        ids.sort();
        assert_eq!(ids, vec!["b", "c"]);

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Success
        );
    }

    /// A terminal failure of `a` marks it Failure and skips both dependents.
    #[tokio::test]
    async fn test_failure_skips_downstream() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        scheduler
            .handle_event(JobEvent::Failed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                error: "boom".into(),
                retryable: false,
            })
            .await
            .unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Failure
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "b").unwrap().status,
            JobStatus::Skipped
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "c").unwrap().status,
            JobStatus::Skipped
        );
    }

    /// Cancelling a started workflow marks every job Cancelled (all three
    /// were Queued or Running at the time of cancellation).
    #[tokio::test]
    async fn test_cancel_workflow() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();
        scheduler.cancel_workflow("wf1").await.unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        for job in &wf.jobs {
            assert!(
                job.status == JobStatus::Cancelled,
                "job {} should be cancelled, got {:?}",
                job.id,
                job.status
            );
        }
    }
}