1mod sweep;
7pub use sweep::state_sweep;
8
9mod subscriber;
10pub use subscriber::StateSubscriber;
11
12mod config;
13pub use config::StateConfig;
14
15use std::{
16 collections::{HashMap, VecDeque},
17 sync::Arc,
18 time::SystemTime,
19};
20
21use parking_lot::RwLock;
22use tracing::debug;
23
24use solti_model::{Slot, Task, TaskId, TaskPage, TaskPhase, TaskQuery, TaskRun, TaskSpec};
25
26#[derive(Clone)]
34pub struct TaskState {
35 inner: Arc<RwLock<TaskStateInner>>,
36}
37
38struct TaskStateInner {
39 tasks: HashMap<TaskId, Task>,
41 by_slot: HashMap<Slot, Vec<TaskId>>,
43 runs: HashMap<TaskId, VecDeque<TaskRun>>,
45}
46
47impl TaskState {
48 pub fn new() -> Self {
50 Self {
51 inner: Arc::new(RwLock::new(TaskStateInner {
52 by_slot: HashMap::new(),
53 tasks: HashMap::new(),
54 runs: HashMap::new(),
55 })),
56 }
57 }
58
59 pub fn add_task(&self, id: TaskId, spec: TaskSpec) {
61 let mut inner = self.inner.write();
62
63 let slot = spec.slot().clone();
64 let task = Task::new(id.clone(), spec);
65
66 inner.by_slot.entry(slot).or_default().push(id.clone());
67 inner.tasks.insert(id, task);
68 }
69
70 pub fn unregister_task(&self, id: &TaskId) {
75 let mut inner = self.inner.write();
76
77 if let Some(task) = inner.tasks.remove(id)
78 && let Some(ids) = inner.by_slot.get_mut(task.slot())
79 {
80 ids.retain(|task_id| task_id != id);
81 if ids.is_empty() {
82 inner.by_slot.remove(task.slot());
83 }
84 }
85 }
86
87 pub fn delete_task(&self, id: &TaskId) -> bool {
92 let mut inner = self.inner.write();
93 inner.runs.remove(id);
94
95 if let Some(task) = inner.tasks.remove(id) {
96 if let Some(ids) = inner.by_slot.get_mut(task.slot()) {
97 ids.retain(|task_id| task_id != id);
98 if ids.is_empty() {
99 inner.by_slot.remove(task.slot());
100 }
101 }
102 true
103 } else {
104 false
105 }
106 }
107
108 pub fn transition_starting(&self, id: &TaskId) -> Option<u32> {
110 let mut inner = self.inner.write();
111
112 let attempt = if let Some(task) = inner.tasks.get_mut(id) {
113 task.transition_starting();
114 task.status().attempt
115 } else {
116 return None;
117 };
118
119 let run = TaskRun::starting(attempt);
120 inner.runs.entry(id.clone()).or_default().push_back(run);
121
122 Some(attempt)
123 }
124
125 pub fn transition_finished(
127 &self,
128 id: &TaskId,
129 phase: TaskPhase,
130 error: Option<String>,
131 exit_code: Option<i32>,
132 ) -> bool {
133 let mut inner = self.inner.write();
134
135 let found = if let Some(task) = inner.tasks.get_mut(id) {
136 match task.transition_finished(phase, error.clone(), exit_code) {
137 Ok(()) => true,
138 Err(e) => {
139 tracing::warn!(task = %id, error = %e, "ignoring illegal transition");
140 return false;
141 }
142 }
143 } else {
144 false
145 };
146
147 if let Some(runs) = inner.runs.get_mut(id)
148 && let Some(run) = runs.back_mut().filter(|r| r.is_active())
149 {
150 run.finish(phase, error, exit_code);
151 }
152
153 found
154 }
155
156 pub fn list_runs(&self, id: &TaskId) -> Vec<TaskRun> {
158 let inner = self.inner.read();
159 inner
160 .runs
161 .get(id)
162 .map(|runs| runs.iter().cloned().collect())
163 .unwrap_or_default()
164 }
165
166 pub fn get(&self, id: &TaskId) -> Option<Task> {
168 let inner = self.inner.read();
169 inner.tasks.get(id).cloned()
170 }
171
172 pub fn list_by_slot(&self, slot: &str) -> Vec<Task> {
174 let inner = self.inner.read();
175
176 inner
177 .by_slot
178 .get(slot)
179 .map(|ids| {
180 ids.iter()
181 .filter_map(|id| inner.tasks.get(id).cloned())
182 .collect()
183 })
184 .unwrap_or_default()
185 }
186
187 pub fn list_all(&self) -> Vec<Task> {
189 let inner = self.inner.read();
190 inner.tasks.values().cloned().collect()
191 }
192
193 pub fn list_by_status(&self, phase: TaskPhase) -> Vec<Task> {
195 let inner = self.inner.read();
196 inner
197 .tasks
198 .values()
199 .filter(|task| task.status().phase == phase)
200 .cloned()
201 .collect()
202 }
203
204 pub fn sweep(&self, config: &StateConfig) -> (usize, usize) {
212 let mut inner = self.inner.write();
213 let now = SystemTime::now();
214 let mut runs_removed = 0usize;
215 let mut tasks_removed = 0usize;
216
217 for runs in inner.runs.values_mut() {
218 let before = runs.len();
219 runs.retain(|run| {
220 if let Some(finished) = run.finished_at {
221 now.duration_since(finished)
222 .map(|age| age < config.run_ttl)
223 .unwrap_or(true)
224 } else {
225 true
226 }
227 });
228 runs_removed += before - runs.len();
229 }
230 inner.runs.retain(|_, runs| !runs.is_empty());
231
232 let expired_tasks: Vec<TaskId> = inner
233 .tasks
234 .iter()
235 .filter(|(id, task)| {
236 task.status().phase.is_terminal()
237 && inner.runs.get(*id).is_none_or(|runs| runs.is_empty())
238 && now
239 .duration_since(task.metadata().updated_at)
240 .map(|age| age >= config.task_ttl)
241 .unwrap_or(false)
242 })
243 .map(|(id, _)| id.clone())
244 .collect();
245
246 for id in &expired_tasks {
247 if let Some(task) = inner.tasks.remove(id) {
248 if let Some(ids) = inner.by_slot.get_mut(task.slot()) {
249 ids.retain(|task_id| task_id != id);
250 if ids.is_empty() {
251 inner.by_slot.remove(task.slot());
252 }
253 }
254 tasks_removed += 1;
255 }
256 }
257 if runs_removed > 0 || tasks_removed > 0 {
258 debug!(runs_removed, tasks_removed, "state sweep completed");
259 }
260
261 (runs_removed, tasks_removed)
262 }
263
264 pub fn query(&self, q: &TaskQuery) -> TaskPage<Task> {
270 let inner = self.inner.read();
271
272 let iter: Box<dyn Iterator<Item = &Task>> = match q.slot() {
273 Some(slot) => {
274 let ids = inner.by_slot.get(slot.as_str());
275 match ids {
276 Some(ids) => Box::new(ids.iter().filter_map(|id| inner.tasks.get(id))),
277 None => {
278 return TaskPage {
279 items: vec![],
280 total: 0,
281 };
282 }
283 }
284 }
285 None => Box::new(inner.tasks.values()),
286 };
287
288 let iter: Box<dyn Iterator<Item = &Task>> = if q.status_filters().is_empty() {
289 iter
290 } else {
291 Box::new(iter.filter(|task| q.matches_phase(&task.status().phase)))
292 };
293
294 let mut filtered: Vec<&Task> = iter.collect();
295 filtered.sort_by(|a, b| a.metadata().id.cmp(&b.metadata().id));
296 let total = filtered.len();
297
298 let start = q.offset().min(total);
299 let items = filtered[start..]
300 .iter()
301 .take(q.limit())
302 .map(|task| (*task).clone())
303 .collect();
304
305 TaskPage { items, total }
306 }
307}
308
309impl Default for TaskState {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use solti_model::TaskKind;

    /// Builds a minimal valid embedded spec bound to `slot`.
    fn default_spec_with_slot(slot: &str) -> TaskSpec {
        TaskSpec::builder(slot, TaskKind::Embedded, 5_000_u64)
            .build()
            .expect("valid spec")
    }

    /// Builds a minimal valid spec bound to the slot `"slot"`.
    fn default_spec() -> TaskSpec {
        default_spec_with_slot("slot")
    }

    /// Sweep configuration with the given TTLs. The interval is irrelevant here
    /// because these tests call `TaskState::sweep` directly.
    fn sweep_config(run_ttl: Duration, task_ttl: Duration) -> StateConfig {
        StateConfig {
            run_ttl,
            task_ttl,
            sweep_interval: Duration::from_secs(60),
        }
    }

    #[test]
    fn add_and_get_task() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec_with_slot("demo-slot"));

        let task = state.get(&id).expect("task should exist");
        assert_eq!(task.metadata().id, id);
        assert_eq!(task.slot(), "demo-slot");
        assert_eq!(task.status().phase, TaskPhase::Pending);
        assert_eq!(task.status().attempt, 0);
    }

    #[test]
    fn default_state_is_empty() {
        let state = TaskState::default();
        assert!(state.list_all().is_empty());
    }

    #[test]
    fn transition_starting_changes_phase_and_attempt() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);

        let task = state.get(&id).unwrap();
        assert_eq!(task.status().phase, TaskPhase::Running);
        assert!(task.status().error.is_none());
        assert_eq!(task.status().attempt, 1);
    }

    #[test]
    fn transition_finished_records_error() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Failed, Some("timeout".to_string()), None);

        let task = state.get(&id).unwrap();
        assert_eq!(task.status().phase, TaskPhase::Failed);
        assert_eq!(task.status().error.as_deref(), Some("timeout"));
    }

    #[test]
    fn transition_finished_unknown_task_returns_false() {
        let state = TaskState::new();
        let id = TaskId::from("nope");

        assert!(!state.transition_finished(&id, TaskPhase::Succeeded, None, None));
        assert!(state.list_runs(&id).is_empty());
    }

    #[test]
    fn transition_finished_closes_run_of_unregistered_task() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.unregister_task(&id);

        assert!(!state.transition_finished(&id, TaskPhase::Failed, Some("gone".into()), None));

        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 1);
        assert!(!runs[0].is_active());
        assert_eq!(runs[0].phase, TaskPhase::Failed);
    }

    #[test]
    fn multiple_starts_increment_attempt() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        assert_eq!(state.transition_starting(&id), Some(1));
        state.transition_finished(&id, TaskPhase::Failed, None, None);
        assert_eq!(state.transition_starting(&id), Some(2));

        let task = state.get(&id).unwrap();
        assert_eq!(task.status().attempt, 2);
    }

    #[test]
    fn unregister_task_removes_from_state() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        assert!(state.get(&id).is_some());

        state.unregister_task(&id);
        assert!(state.get(&id).is_none());
    }

    #[test]
    fn unregister_task_removes_from_slot_index() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec_with_slot("slot-a"));
        state.unregister_task(&id);

        assert!(state.list_by_slot("slot-a").is_empty());
        assert_eq!(state.query(&TaskQuery::new().with_slot("slot-a")).total, 0);
    }

    #[test]
    fn delete_task_removes_task_runs_and_slot_index() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec_with_slot("slot-a"));
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        assert!(state.delete_task(&id));
        assert!(state.get(&id).is_none());
        assert!(state.list_runs(&id).is_empty());
        assert!(state.list_by_slot("slot-a").is_empty());

        // A second delete finds nothing.
        assert!(!state.delete_task(&id));
    }

    #[test]
    fn delete_task_after_unregister_drops_runs() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);
        state.unregister_task(&id);

        assert!(!state.delete_task(&id));
        assert!(state.list_runs(&id).is_empty());
    }

    #[test]
    fn list_by_slot_returns_correct_tasks() {
        let state = TaskState::new();

        state.add_task(TaskId::from("task-1"), default_spec_with_slot("slot-a"));
        state.add_task(TaskId::from("task-2"), default_spec_with_slot("slot-a"));
        state.add_task(TaskId::from("task-3"), default_spec_with_slot("slot-b"));

        let slot_a_tasks = state.list_by_slot("slot-a");
        assert_eq!(slot_a_tasks.len(), 2);

        let slot_b_tasks = state.list_by_slot("slot-b");
        assert_eq!(slot_b_tasks.len(), 1);
    }

    #[test]
    fn list_by_slot_unknown_slot_is_empty() {
        let state = TaskState::new();
        state.add_task(TaskId::from("task-1"), default_spec_with_slot("slot-a"));

        assert!(state.list_by_slot("nonexistent").is_empty());
    }

    #[test]
    fn list_by_status_filters_correctly() {
        let state = TaskState::new();
        let id1 = TaskId::from("task-1");
        let id2 = TaskId::from("task-2");

        state.add_task(id1.clone(), default_spec());
        state.add_task(id2.clone(), default_spec());
        state.transition_starting(&id1);

        let running_tasks = state.list_by_status(TaskPhase::Running);
        assert_eq!(running_tasks.len(), 1);
        assert_eq!(running_tasks[0].metadata().id, id1);

        let pending_tasks = state.list_by_status(TaskPhase::Pending);
        assert_eq!(pending_tasks.len(), 1);
        assert_eq!(pending_tasks[0].metadata().id, id2);
    }

    #[test]
    fn list_all_returns_all_tasks() {
        let state = TaskState::new();

        state.add_task(TaskId::from("task-1"), default_spec_with_slot("slot-a"));
        state.add_task(TaskId::from("task-2"), default_spec_with_slot("slot-b"));
        state.add_task(TaskId::from("task-3"), default_spec_with_slot("slot-c"));

        let all_tasks = state.list_all();
        assert_eq!(all_tasks.len(), 3);
    }

    #[test]
    fn transition_starting_creates_active_run() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);

        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].attempt, 1);
        assert!(runs[0].is_active());
    }

    #[test]
    fn transition_finished_closes_active_run() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 1);
        assert!(!runs[0].is_active());
        assert_eq!(runs[0].phase, TaskPhase::Succeeded);
    }

    #[test]
    fn multiple_runs_ordered_by_attempt() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());

        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Failed, Some("err".into()), None);

        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 2);
        assert_eq!(runs[0].attempt, 1);
        assert_eq!(runs[0].phase, TaskPhase::Failed);
        assert_eq!(runs[1].attempt, 2);
        assert_eq!(runs[1].phase, TaskPhase::Succeeded);
    }

    #[test]
    fn unregister_task_preserves_runs() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        state.unregister_task(&id);

        assert!(state.get(&id).is_none());
        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 1);
    }

    #[test]
    fn list_runs_empty_for_unknown_task() {
        let state = TaskState::new();
        let runs = state.list_runs(&TaskId::from("nonexistent"));
        assert!(runs.is_empty());
    }

    /// Five tasks for the query tests.
    ///
    /// slot-a: a1 and a2 are Running, a3 is Pending.
    /// slot-b: b1 is Failed, b2 is Pending.
    fn setup_query_state() -> TaskState {
        let state = TaskState::new();
        state.add_task(TaskId::from("a1"), default_spec_with_slot("slot-a"));
        state.add_task(TaskId::from("a2"), default_spec_with_slot("slot-a"));
        state.add_task(TaskId::from("a3"), default_spec_with_slot("slot-a"));
        state.transition_starting(&TaskId::from("a1"));
        state.transition_starting(&TaskId::from("a2"));

        state.add_task(TaskId::from("b1"), default_spec_with_slot("slot-b"));
        state.add_task(TaskId::from("b2"), default_spec_with_slot("slot-b"));
        state.transition_starting(&TaskId::from("b1"));
        state.transition_finished(
            &TaskId::from("b1"),
            TaskPhase::Failed,
            Some("err".into()),
            None,
        );

        state
    }

    #[test]
    fn query_no_filters_returns_all() {
        let state = setup_query_state();
        let page = state.query(&TaskQuery::new().with_limit(100));
        assert_eq!(page.total, 5);
        assert_eq!(page.items.len(), 5);
    }

    #[test]
    fn query_by_slot_only() {
        let state = setup_query_state();
        let page = state.query(&TaskQuery::new().with_slot("slot-a"));
        assert_eq!(page.total, 3);
        assert_eq!(page.items.len(), 3);
    }

    #[test]
    fn query_by_status_only() {
        let state = setup_query_state();
        let page = state.query(&TaskQuery::new().with_status(TaskPhase::Running));
        assert_eq!(page.total, 2);
        assert_eq!(page.items.len(), 2);
    }

    #[test]
    fn query_by_slot_and_status() {
        let state = setup_query_state();
        let page = state.query(
            &TaskQuery::new()
                .with_slot("slot-a")
                .with_status(TaskPhase::Running),
        );
        assert_eq!(page.total, 2);
        assert!(
            page.items
                .iter()
                .all(|t| t.status().phase == TaskPhase::Running)
        );
    }

    #[test]
    fn query_by_slot_and_status_no_match() {
        let state = setup_query_state();
        let page = state.query(
            &TaskQuery::new()
                .with_slot("slot-b")
                .with_status(TaskPhase::Running),
        );
        assert_eq!(page.total, 0);
        assert!(page.items.is_empty());
    }

    #[test]
    fn query_unknown_slot_returns_empty() {
        let state = setup_query_state();
        let page = state.query(&TaskQuery::new().with_slot("nonexistent"));
        assert_eq!(page.total, 0);
        assert!(page.items.is_empty());
    }

    #[test]
    fn query_pagination_offset_and_limit() {
        let state = setup_query_state();
        let page = state.query(&TaskQuery::new().with_limit(2).with_offset(2));
        assert_eq!(page.total, 5);
        assert_eq!(page.items.len(), 2);
    }

    #[test]
    fn query_offset_beyond_total() {
        let state = setup_query_state();
        let page = state.query(&TaskQuery::new().with_offset(100));
        assert_eq!(page.total, 5);
        assert!(page.items.is_empty());
    }

    #[test]
    fn query_limit_larger_than_remaining() {
        let state = setup_query_state();
        let page = state.query(&TaskQuery::new().with_offset(3).with_limit(100));
        assert_eq!(page.total, 5);
        assert_eq!(page.items.len(), 2);
    }

    #[test]
    fn query_slot_with_pagination() {
        let state = setup_query_state();
        let page = state.query(
            &TaskQuery::new()
                .with_slot("slot-a")
                .with_offset(1)
                .with_limit(1),
        );
        assert_eq!(page.total, 3);
        assert_eq!(page.items.len(), 1);
    }

    #[test]
    fn sweep_removes_expired_runs() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        let config = sweep_config(Duration::ZERO, Duration::from_secs(3600));

        let (runs_removed, tasks_removed) = state.sweep(&config);
        assert_eq!(runs_removed, 1);
        assert_eq!(tasks_removed, 0);
        assert!(state.list_runs(&id).is_empty());
    }

    #[test]
    fn sweep_removes_terminal_tasks_without_runs() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        let config = sweep_config(Duration::ZERO, Duration::ZERO);

        let (_, tasks_removed) = state.sweep(&config);
        assert_eq!(tasks_removed, 1);
        assert!(state.get(&id).is_none());
    }

    #[test]
    fn sweep_keeps_active_runs() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);

        let config = sweep_config(Duration::ZERO, Duration::ZERO);

        let (runs_removed, _) = state.sweep(&config);
        assert_eq!(runs_removed, 0);
        assert_eq!(state.list_runs(&id).len(), 1);
    }

    #[test]
    fn sweep_keeps_non_terminal_tasks() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);

        let config = sweep_config(Duration::ZERO, Duration::ZERO);

        let (_, tasks_removed) = state.sweep(&config);
        assert_eq!(tasks_removed, 0);
        assert!(state.get(&id).is_some());
    }

    #[test]
    fn transition_starting_atomically_updates_state() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());

        let attempt = state.transition_starting(&id);
        assert_eq!(attempt, Some(1));

        let task = state.get(&id).unwrap();
        assert_eq!(task.status().phase, TaskPhase::Running);
        assert_eq!(task.status().attempt, 1);

        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].attempt, 1);
        assert!(runs[0].is_active());
    }

    #[test]
    fn transition_starting_returns_none_for_unknown_task() {
        let state = TaskState::new();
        assert_eq!(state.transition_starting(&TaskId::from("nope")), None);
    }

    #[test]
    fn transition_finished_atomically_updates_state() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);

        state.transition_finished(&id, TaskPhase::Failed, Some("boom".into()), None);

        let task = state.get(&id).unwrap();
        assert_eq!(task.status().phase, TaskPhase::Failed);
        assert_eq!(task.status().error.as_deref(), Some("boom"));

        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 1);
        assert!(!runs[0].is_active());
        assert_eq!(runs[0].phase, TaskPhase::Failed);
    }

    #[test]
    fn transition_finished_success_no_error() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());
        state.transition_starting(&id);
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        let task = state.get(&id).unwrap();
        assert_eq!(task.status().phase, TaskPhase::Succeeded);
        assert!(task.status().error.is_none());

        let runs = state.list_runs(&id);
        assert_eq!(runs[0].phase, TaskPhase::Succeeded);
        assert!(!runs[0].is_active());
    }

    #[test]
    fn transition_starting_multiple_attempts() {
        let state = TaskState::new();
        let id = TaskId::from("task-1");

        state.add_task(id.clone(), default_spec());

        assert_eq!(state.transition_starting(&id), Some(1));
        state.transition_finished(&id, TaskPhase::Failed, Some("err".into()), None);

        assert_eq!(state.transition_starting(&id), Some(2));
        state.transition_finished(&id, TaskPhase::Succeeded, None, None);

        let task = state.get(&id).unwrap();
        assert_eq!(task.status().attempt, 2);
        assert_eq!(task.status().phase, TaskPhase::Succeeded);

        let runs = state.list_runs(&id);
        assert_eq!(runs.len(), 2);
        assert_eq!(runs[0].attempt, 1);
        assert_eq!(runs[0].phase, TaskPhase::Failed);
        assert_eq!(runs[1].attempt, 2);
        assert_eq!(runs[1].phase, TaskPhase::Succeeded);
    }
}