1use crate::backend::{BackendError, SignalStore, SnapshotStore, TaskClaimStore};
7use chrono::{Duration, Utc};
8use sayiir_core::snapshot::{
9 ExecutionPosition, PauseRequest, SignalKind, SignalRequest, WorkflowSnapshot,
10 WorkflowSnapshotState,
11};
12use sayiir_core::task_claim::{AvailableTask, TaskClaim};
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, RwLock};
15
16#[derive(Clone, Default)]
22pub struct InMemoryBackend {
23 snapshots: Arc<RwLock<HashMap<String, WorkflowSnapshot>>>,
24 claims: Arc<RwLock<HashMap<String, TaskClaim>>>, signals: Arc<RwLock<HashMap<String, HashMap<SignalKind, SignalRequest>>>>,
26 #[allow(clippy::type_complexity)]
28 events: Arc<RwLock<HashMap<(String, String), VecDeque<bytes::Bytes>>>>,
29}
30
31impl InMemoryBackend {
32 pub fn new() -> Self {
34 Default::default()
35 }
36
37 fn claim_key(instance_id: &str, task_id: &str) -> String {
38 format!("{}:{}", instance_id, task_id)
39 }
40
41 fn lock_error<E: std::fmt::Display>(e: E) -> BackendError {
43 BackendError::Backend(format!("Lock error: {e}"))
44 }
45}
46
47impl SnapshotStore for InMemoryBackend {
52 async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
53 let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
54 snapshots.insert(snapshot.instance_id.clone(), snapshot.clone());
55 Ok(())
56 }
57
58 async fn save_task_result(
59 &self,
60 instance_id: &str,
61 task_id: &str,
62 output: bytes::Bytes,
63 ) -> Result<(), BackendError> {
64 let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
65
66 let snapshot = snapshots
67 .get_mut(instance_id)
68 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
69
70 snapshot.mark_task_completed(task_id.to_string(), output);
71 Ok(())
72 }
73
74 async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
75 let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
76 snapshots
77 .get(instance_id)
78 .cloned()
79 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))
80 }
81
82 async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
83 let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
84 snapshots
85 .remove(instance_id)
86 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))
87 .map(|_| ())
88 }
89
90 async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
91 let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
92 Ok(snapshots.keys().cloned().collect())
93 }
94}
95
96impl SignalStore for InMemoryBackend {
101 async fn store_signal(
102 &self,
103 instance_id: &str,
104 kind: SignalKind,
105 request: SignalRequest,
106 ) -> Result<(), BackendError> {
107 {
109 let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
110 let snapshot = snapshots
111 .get(instance_id)
112 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
113
114 match kind {
115 SignalKind::Cancel => {
116 if snapshot.state.is_completed() {
117 return Err(BackendError::CannotCancel("Completed".to_string()));
118 }
119 if snapshot.state.is_failed() {
120 return Err(BackendError::CannotCancel("Failed".to_string()));
121 }
122 if snapshot.state.is_cancelled() {
123 return Ok(()); }
125 }
126 SignalKind::Pause => {
127 if snapshot.state.is_completed() {
128 return Err(BackendError::CannotPause("Completed".to_string()));
129 }
130 if snapshot.state.is_failed() {
131 return Err(BackendError::CannotPause("Failed".to_string()));
132 }
133 if snapshot.state.is_cancelled() {
134 return Err(BackendError::CannotPause("Cancelled".to_string()));
135 }
136 if snapshot.state.is_paused() {
137 return Ok(()); }
139 }
140 }
141 }
142
143 let mut signals = self.signals.write().map_err(Self::lock_error)?;
144 signals
145 .entry(instance_id.to_string())
146 .or_default()
147 .insert(kind, request);
148 Ok(())
149 }
150
151 async fn get_signal(
152 &self,
153 instance_id: &str,
154 kind: SignalKind,
155 ) -> Result<Option<SignalRequest>, BackendError> {
156 let signals = self.signals.read().map_err(Self::lock_error)?;
157 Ok(signals.get(instance_id).and_then(|m| m.get(&kind)).cloned())
158 }
159
160 async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
161 let mut signals = self.signals.write().map_err(Self::lock_error)?;
162 if let Some(inner) = signals.get_mut(instance_id) {
163 inner.remove(&kind);
164 if inner.is_empty() {
165 signals.remove(instance_id);
166 }
167 }
168 Ok(())
169 }
170
171 async fn send_event(
172 &self,
173 instance_id: &str,
174 signal_name: &str,
175 payload: bytes::Bytes,
176 ) -> Result<(), BackendError> {
177 let mut events = self.events.write().map_err(Self::lock_error)?;
178 events
179 .entry((instance_id.to_string(), signal_name.to_string()))
180 .or_default()
181 .push_back(payload);
182 Ok(())
183 }
184
185 async fn consume_event(
186 &self,
187 instance_id: &str,
188 signal_name: &str,
189 ) -> Result<Option<bytes::Bytes>, BackendError> {
190 let mut events = self.events.write().map_err(Self::lock_error)?;
191 let key = (instance_id.to_string(), signal_name.to_string());
192 let payload = events.get_mut(&key).and_then(VecDeque::pop_front);
193 if events.get(&key).is_some_and(VecDeque::is_empty) {
195 events.remove(&key);
196 }
197 Ok(payload)
198 }
199
200 async fn check_and_cancel(
202 &self,
203 instance_id: &str,
204 interrupted_at_task: Option<&str>,
205 ) -> Result<bool, BackendError> {
206 let request = {
207 let signals = self.signals.read().map_err(Self::lock_error)?;
208 match signals
209 .get(instance_id)
210 .and_then(|m| m.get(&SignalKind::Cancel))
211 {
212 Some(req) => req.clone(),
213 None => return Ok(false),
214 }
215 };
216
217 {
218 let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
219 let snapshot = snapshots
220 .get_mut(instance_id)
221 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
222 if !snapshot.state.is_in_progress() {
223 return Ok(false);
224 }
225 snapshot.mark_cancelled(
226 request.reason,
227 request.requested_by,
228 interrupted_at_task.map(String::from),
229 );
230 }
231
232 {
233 let mut signals = self.signals.write().map_err(Self::lock_error)?;
234 if let Some(inner) = signals.get_mut(instance_id) {
235 inner.remove(&SignalKind::Cancel);
236 if inner.is_empty() {
237 signals.remove(instance_id);
238 }
239 }
240 }
241
242 Ok(true)
243 }
244
245 async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
247 let request = {
248 let signals = self.signals.read().map_err(Self::lock_error)?;
249 match signals
250 .get(instance_id)
251 .and_then(|m| m.get(&SignalKind::Pause))
252 {
253 Some(req) => req.clone(),
254 None => return Ok(false),
255 }
256 };
257
258 {
259 let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
260 let snapshot = snapshots
261 .get_mut(instance_id)
262 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
263 if !snapshot.state.is_in_progress() {
264 return Ok(false);
265 }
266 let pause_request: PauseRequest = request.into();
267 snapshot.mark_paused(&pause_request);
268 }
269
270 {
271 let mut signals = self.signals.write().map_err(Self::lock_error)?;
272 if let Some(inner) = signals.get_mut(instance_id) {
273 inner.remove(&SignalKind::Pause);
274 if inner.is_empty() {
275 signals.remove(instance_id);
276 }
277 }
278 }
279
280 Ok(true)
281 }
282
283 async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
285 let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
286
287 let snapshot = snapshots
288 .get_mut(instance_id)
289 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
290
291 if !snapshot.state.is_paused() {
292 return Err(BackendError::CannotPause(format!(
293 "Workflow is not paused (current state: {:?})",
294 if snapshot.state.is_in_progress() {
295 "InProgress"
296 } else if snapshot.state.is_completed() {
297 "Completed"
298 } else if snapshot.state.is_failed() {
299 "Failed"
300 } else if snapshot.state.is_cancelled() {
301 "Cancelled"
302 } else {
303 "Unknown"
304 }
305 )));
306 }
307
308 snapshot.mark_unpaused();
309 Ok(snapshot.clone())
310 }
311}
312
313impl TaskClaimStore for InMemoryBackend {
318 async fn claim_task(
319 &self,
320 instance_id: &str,
321 task_id: &str,
322 worker_id: &str,
323 ttl: Option<Duration>,
324 ) -> Result<Option<TaskClaim>, BackendError> {
325 let key = Self::claim_key(instance_id, task_id);
326 let mut claims = self.claims.write().map_err(Self::lock_error)?;
327
328 if let Some(existing_claim) = claims.get(&key) {
330 if !existing_claim.is_expired() {
331 return Ok(None); }
333 claims.remove(&key);
335 }
336
337 let claim = TaskClaim::new(
339 instance_id.to_string(),
340 task_id.to_string(),
341 worker_id.to_string(),
342 ttl,
343 );
344 claims.insert(key, claim.clone());
345 Ok(Some(claim))
346 }
347
348 async fn release_task_claim(
349 &self,
350 instance_id: &str,
351 task_id: &str,
352 worker_id: &str,
353 ) -> Result<(), BackendError> {
354 let key = Self::claim_key(instance_id, task_id);
355 let mut claims = self.claims.write().map_err(Self::lock_error)?;
356
357 if let Some(claim) = claims.get(&key) {
358 if claim.worker_id != worker_id {
359 return Err(BackendError::Backend(format!(
360 "Claim owned by different worker: {}",
361 claim.worker_id
362 )));
363 }
364 claims.remove(&key);
365 Ok(())
366 } else {
367 Err(BackendError::NotFound(format!(
368 "{}:{}",
369 instance_id, task_id
370 )))
371 }
372 }
373
374 async fn extend_task_claim(
375 &self,
376 instance_id: &str,
377 task_id: &str,
378 worker_id: &str,
379 additional_duration: Duration,
380 ) -> Result<(), BackendError> {
381 let key = Self::claim_key(instance_id, task_id);
382 let mut claims = self.claims.write().map_err(Self::lock_error)?;
383
384 if let Some(claim) = claims.get_mut(&key) {
385 if claim.worker_id != worker_id {
386 return Err(BackendError::Backend(format!(
387 "Claim owned by different worker: {}",
388 claim.worker_id
389 )));
390 }
391
392 if let Some(expires_at) = claim.expires_at {
393 let expires_datetime = chrono::DateTime::from_timestamp(expires_at as i64, 0)
394 .ok_or_else(|| BackendError::Backend("Invalid timestamp".to_string()))?;
395 let new_expiry = expires_datetime
396 .checked_add_signed(additional_duration)
397 .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
398 claim.expires_at = Some(new_expiry.timestamp() as u64);
399 }
400 Ok(())
401 } else {
402 Err(BackendError::NotFound(format!(
403 "{}:{}",
404 instance_id, task_id
405 )))
406 }
407 }
408
409 async fn find_available_tasks(
410 &self,
411 worker_id: &str,
412 limit: usize,
413 ) -> Result<Vec<AvailableTask>, BackendError> {
414 {
416 let mut claims = self.claims.write().map_err(Self::lock_error)?;
417 claims.retain(|_, claim| !claim.is_expired());
418 }
419
420 let mut delay_advances: Vec<(String, String)> = Vec::new();
422 let mut delay_completions: Vec<(String, String)> = Vec::new();
423 let mut signal_advances: Vec<(String, String, Option<String>)> = Vec::new();
425 let mut signal_timeout_advances: Vec<(String, String, Option<String>)> = Vec::new();
427
428 {
429 let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
430 let signals = self.signals.read().map_err(Self::lock_error)?;
431 let events = self.events.read().map_err(Self::lock_error)?;
432
433 for (instance_id, snapshot) in snapshots.iter() {
434 if !snapshot.state.is_in_progress() {
435 continue;
436 }
437 if signals
438 .get(instance_id.as_str())
439 .is_some_and(|m| m.contains_key(&SignalKind::Cancel))
440 {
441 continue;
442 }
443 if signals
444 .get(instance_id.as_str())
445 .is_some_and(|m| m.contains_key(&SignalKind::Pause))
446 {
447 continue;
448 }
449 match &snapshot.state {
450 WorkflowSnapshotState::InProgress {
451 position:
452 ExecutionPosition::AtDelay {
453 wake_at,
454 delay_id,
455 next_task_id,
456 ..
457 },
458 ..
459 } if Utc::now() >= *wake_at => {
460 if let Some(next_id) = next_task_id {
461 delay_advances.push((instance_id.clone(), next_id.clone()));
462 } else {
463 delay_completions.push((instance_id.clone(), delay_id.clone()));
464 }
465 }
466 WorkflowSnapshotState::InProgress {
467 position:
468 ExecutionPosition::AtSignal {
469 signal_name,
470 wake_at,
471 next_task_id,
472 ..
473 },
474 ..
475 } => {
476 let key = (instance_id.clone(), signal_name.clone());
477 if events.get(&key).is_some_and(|q| !q.is_empty()) {
478 signal_advances.push((
480 instance_id.clone(),
481 signal_name.clone(),
482 next_task_id.clone(),
483 ));
484 } else if wake_at.is_some_and(|wt| Utc::now() >= wt) {
485 signal_timeout_advances.push((
487 instance_id.clone(),
488 signal_name.clone(),
489 next_task_id.clone(),
490 ));
491 }
492 }
493 _ => {}
494 }
495 }
496 }
497
498 if !delay_advances.is_empty()
500 || !delay_completions.is_empty()
501 || !signal_advances.is_empty()
502 || !signal_timeout_advances.is_empty()
503 {
504 let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
505 for (instance_id, next_task_id) in &delay_advances {
506 if let Some(snapshot) = snapshots.get_mut(instance_id) {
507 snapshot.update_position(ExecutionPosition::AtTask {
508 task_id: next_task_id.clone(),
509 });
510 }
511 }
512 for (instance_id, delay_id) in &delay_completions {
513 if let Some(snapshot) = snapshots.get_mut(instance_id) {
514 let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
515 snapshot.mark_completed(output);
516 }
517 }
518 {
520 let mut events = self.events.write().map_err(Self::lock_error)?;
521 for (instance_id, signal_name, next_task_id) in &signal_advances {
522 let key = (instance_id.clone(), signal_name.clone());
523 let payload = events
524 .get_mut(&key)
525 .and_then(VecDeque::pop_front)
526 .unwrap_or_default();
527 if events.get(&key).is_some_and(VecDeque::is_empty) {
529 events.remove(&key);
530 }
531 if let Some(snapshot) = snapshots.get_mut(instance_id) {
532 snapshot.mark_task_completed(signal_name.clone(), payload);
534 if let Some(next_id) = next_task_id {
535 snapshot.update_position(ExecutionPosition::AtTask {
536 task_id: next_id.clone(),
537 });
538 } else {
539 let output = snapshot
540 .get_task_result_bytes(signal_name)
541 .unwrap_or_default();
542 snapshot.mark_completed(output);
543 }
544 }
545 }
546 }
547 for (instance_id, signal_name, next_task_id) in &signal_timeout_advances {
549 if let Some(snapshot) = snapshots.get_mut(instance_id) {
550 snapshot.mark_task_completed(signal_name.clone(), bytes::Bytes::new());
551 if let Some(next_id) = next_task_id {
552 snapshot.update_position(ExecutionPosition::AtTask {
553 task_id: next_id.clone(),
554 });
555 } else {
556 snapshot.mark_completed(bytes::Bytes::new());
557 }
558 }
559 }
560 }
561
562 let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
563 let claims = self.claims.read().map_err(Self::lock_error)?;
564 let signals = self.signals.read().map_err(Self::lock_error)?;
565
566 let mut available = Vec::new();
567
568 for (instance_id, snapshot) in snapshots.iter() {
569 if !snapshot.state.is_in_progress() {
570 continue;
571 }
572
573 if let Some(instance_signals) = signals.get(instance_id.as_str())
575 && (instance_signals.contains_key(&SignalKind::Cancel)
576 || instance_signals.contains_key(&SignalKind::Pause))
577 {
578 continue;
579 }
580
581 if let WorkflowSnapshotState::InProgress {
582 completed_tasks,
583 position: ExecutionPosition::AtTask { task_id },
584 ..
585 } = &snapshot.state
586 {
587 let claim_key = Self::claim_key(instance_id, task_id);
588 let is_claimed = claims.contains_key(&claim_key);
589 let is_completed = completed_tasks.contains_key(task_id);
590
591 if !is_completed && !is_claimed {
592 if let Some(rs) = snapshot.task_retries.get(task_id)
594 && Utc::now() < rs.next_retry_at
595 {
596 continue;
597 }
598
599 let input = if completed_tasks.is_empty() {
600 snapshot.initial_input_bytes()
601 } else {
602 snapshot.get_last_task_output()
603 };
604
605 if let Some(input_bytes) = input {
606 available.push(AvailableTask {
607 instance_id: instance_id.clone(),
608 task_id: task_id.clone(),
609 input: input_bytes,
610 workflow_definition_hash: snapshot.definition_hash.clone(),
611 });
612
613 if available.len() >= limit {
614 break;
615 }
616 }
617 }
618 }
619 }
620
621 available.sort_by_key(|t| {
624 snapshots
625 .get(&t.instance_id)
626 .and_then(|s| s.task_retries.get(&t.task_id))
627 .and_then(|rs| rs.last_failed_worker.as_deref())
628 .is_some_and(|w| w == worker_id)
629 });
630
631 Ok(available)
632 }
633}
634
635#[cfg(test)]
636mod tests {
637 use super::*;
638 use crate::backend::SignalStore;
639 use sayiir_core::snapshot::SignalKind;
640
641 #[tokio::test]
642 async fn test_save_and_load() {
643 let backend = InMemoryBackend::new();
644 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
645
646 backend.save_snapshot(&snapshot).await.unwrap();
647 let loaded = backend.load_snapshot("test-123").await.unwrap();
648
649 assert_eq!(snapshot.instance_id, loaded.instance_id);
650 assert_eq!(snapshot.definition_hash, loaded.definition_hash);
651 }
652
653 #[tokio::test]
654 async fn test_not_found() {
655 let backend = InMemoryBackend::new();
656 let result = backend.load_snapshot("nonexistent").await;
657 assert!(matches!(result, Err(BackendError::NotFound(_))));
658 }
659
660 #[tokio::test]
661 async fn test_delete() {
662 let backend = InMemoryBackend::new();
663 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
664
665 backend.save_snapshot(&snapshot).await.unwrap();
666 backend.delete_snapshot("test-123").await.unwrap();
667
668 let result = backend.load_snapshot("test-123").await;
669 assert!(matches!(result, Err(BackendError::NotFound(_))));
670 }
671
672 #[tokio::test]
673 async fn test_list_snapshots() {
674 let backend = InMemoryBackend::new();
675
676 backend
677 .save_snapshot(&WorkflowSnapshot::new(
678 "test-1".to_string(),
679 "hash-1".to_string(),
680 ))
681 .await
682 .unwrap();
683 backend
684 .save_snapshot(&WorkflowSnapshot::new(
685 "test-2".to_string(),
686 "hash-2".to_string(),
687 ))
688 .await
689 .unwrap();
690
691 let list = backend.list_snapshots().await.unwrap();
692 assert_eq!(list.len(), 2);
693 assert!(list.contains(&"test-1".to_string()));
694 assert!(list.contains(&"test-2".to_string()));
695 }
696
697 #[tokio::test]
700 async fn test_claim_task_success() {
701 let backend = InMemoryBackend::new();
702
703 let claim = backend
704 .claim_task(
705 "workflow-1",
706 "task-1",
707 "worker-1",
708 Some(Duration::seconds(300)),
709 )
710 .await
711 .unwrap();
712
713 assert!(claim.is_some());
714 let claim = claim.unwrap();
715 assert_eq!(claim.instance_id, "workflow-1");
716 assert_eq!(claim.task_id, "task-1");
717 assert_eq!(claim.worker_id, "worker-1");
718 assert!(claim.expires_at.is_some());
719 }
720
721 #[tokio::test]
722 async fn test_claim_task_already_claimed() {
723 let backend = InMemoryBackend::new();
724
725 let claim1 = backend
727 .claim_task(
728 "workflow-1",
729 "task-1",
730 "worker-1",
731 Some(Duration::seconds(300)),
732 )
733 .await
734 .unwrap();
735 assert!(claim1.is_some());
736
737 let claim2 = backend
739 .claim_task(
740 "workflow-1",
741 "task-1",
742 "worker-2",
743 Some(Duration::seconds(300)),
744 )
745 .await
746 .unwrap();
747 assert!(claim2.is_none());
748 }
749
750 #[tokio::test]
751 async fn test_claim_task_expired_claim_replaced() {
752 let backend = InMemoryBackend::new();
753
754 let claim1 = backend
756 .claim_task(
757 "workflow-1",
758 "task-1",
759 "worker-1",
760 Some(Duration::seconds(0)),
761 )
762 .await
763 .unwrap();
764 assert!(claim1.is_some());
765
766 let claim2 = backend
768 .claim_task(
769 "workflow-1",
770 "task-1",
771 "worker-2",
772 Some(Duration::seconds(300)),
773 )
774 .await
775 .unwrap();
776 assert!(claim2.is_some());
777 let claim2 = claim2.unwrap();
778 assert_eq!(claim2.worker_id, "worker-2");
779 }
780
781 #[tokio::test]
782 async fn test_claim_task_no_ttl() {
783 let backend = InMemoryBackend::new();
784
785 let claim = backend
786 .claim_task("workflow-1", "task-1", "worker-1", None)
787 .await
788 .unwrap();
789
790 assert!(claim.is_some());
791 let claim = claim.unwrap();
792 assert!(claim.expires_at.is_none());
793 assert!(!claim.is_expired()); }
795
796 #[tokio::test]
797 async fn test_release_task_claim_success() {
798 let backend = InMemoryBackend::new();
799
800 backend
802 .claim_task(
803 "workflow-1",
804 "task-1",
805 "worker-1",
806 Some(Duration::seconds(300)),
807 )
808 .await
809 .unwrap();
810
811 let result = backend
813 .release_task_claim("workflow-1", "task-1", "worker-1")
814 .await;
815 assert!(result.is_ok());
816
817 let claim = backend
819 .claim_task(
820 "workflow-1",
821 "task-1",
822 "worker-2",
823 Some(Duration::seconds(300)),
824 )
825 .await
826 .unwrap();
827 assert!(claim.is_some());
828 }
829
830 #[tokio::test]
831 async fn test_release_task_claim_wrong_worker() {
832 let backend = InMemoryBackend::new();
833
834 backend
836 .claim_task(
837 "workflow-1",
838 "task-1",
839 "worker-1",
840 Some(Duration::seconds(300)),
841 )
842 .await
843 .unwrap();
844
845 let result = backend
847 .release_task_claim("workflow-1", "task-1", "worker-2")
848 .await;
849 assert!(matches!(result, Err(BackendError::Backend(_))));
850 }
851
852 #[tokio::test]
853 async fn test_release_task_claim_not_found() {
854 let backend = InMemoryBackend::new();
855
856 let result = backend
857 .release_task_claim("workflow-1", "task-1", "worker-1")
858 .await;
859 assert!(matches!(result, Err(BackendError::NotFound(_))));
860 }
861
862 #[tokio::test]
863 async fn test_extend_task_claim_success() {
864 let backend = InMemoryBackend::new();
865
866 let claim = backend
868 .claim_task(
869 "workflow-1",
870 "task-1",
871 "worker-1",
872 Some(Duration::seconds(10)),
873 )
874 .await
875 .unwrap()
876 .unwrap();
877 let original_expiry = claim.expires_at.unwrap();
878
879 backend
881 .extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
882 .await
883 .unwrap();
884
885 let claims = backend.claims.read().unwrap();
887 let key = InMemoryBackend::claim_key("workflow-1", "task-1");
888 let extended_claim = claims.get(&key).unwrap();
889 assert!(extended_claim.expires_at.unwrap() > original_expiry);
890 }
891
892 #[tokio::test]
893 async fn test_extend_task_claim_wrong_worker() {
894 let backend = InMemoryBackend::new();
895
896 backend
898 .claim_task(
899 "workflow-1",
900 "task-1",
901 "worker-1",
902 Some(Duration::seconds(300)),
903 )
904 .await
905 .unwrap();
906
907 let result = backend
909 .extend_task_claim("workflow-1", "task-1", "worker-2", Duration::seconds(300))
910 .await;
911 assert!(matches!(result, Err(BackendError::Backend(_))));
912 }
913
914 #[tokio::test]
915 async fn test_extend_task_claim_not_found() {
916 let backend = InMemoryBackend::new();
917
918 let result = backend
919 .extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
920 .await;
921 assert!(matches!(result, Err(BackendError::NotFound(_))));
922 }
923
924 #[tokio::test]
925 async fn test_extend_task_claim_no_expiry() {
926 let backend = InMemoryBackend::new();
927
928 backend
930 .claim_task("workflow-1", "task-1", "worker-1", None)
931 .await
932 .unwrap();
933
934 backend
936 .extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
937 .await
938 .unwrap();
939
940 let claims = backend.claims.read().unwrap();
941 let key = InMemoryBackend::claim_key("workflow-1", "task-1");
942 let claim = claims.get(&key).unwrap();
943 assert!(claim.expires_at.is_none());
944 }
945
946 #[tokio::test]
947 async fn test_store_signal_cancel_success() {
948 let backend = InMemoryBackend::new();
949 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
950 backend.save_snapshot(&snapshot).await.unwrap();
951
952 let result = backend
953 .store_signal(
954 "test-123",
955 SignalKind::Cancel,
956 SignalRequest::new(
957 Some("User requested".to_string()),
958 Some("admin".to_string()),
959 ),
960 )
961 .await;
962 assert!(result.is_ok(), "store_signal should succeed");
963
964 let stored = backend
965 .get_signal("test-123", SignalKind::Cancel)
966 .await
967 .unwrap();
968 assert!(stored.is_some(), "cancel signal should be stored");
969 let stored = stored.unwrap();
970 assert_eq!(stored.reason, Some("User requested".to_string()));
971 assert_eq!(stored.requested_by, Some("admin".to_string()));
972 }
973
974 #[tokio::test]
975 async fn test_store_signal_cancel_not_found() {
976 let backend = InMemoryBackend::new();
977
978 let result = backend
979 .store_signal(
980 "nonexistent",
981 SignalKind::Cancel,
982 SignalRequest::new(None, None),
983 )
984 .await;
985 assert!(
986 matches!(result, Err(BackendError::NotFound(_))),
987 "should return NotFound for non-existent workflow"
988 );
989 }
990
991 #[tokio::test]
992 async fn test_store_signal_cancel_completed_workflow() {
993 let backend = InMemoryBackend::new();
994 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
995 snapshot.mark_completed(bytes::Bytes::from("result"));
996 backend.save_snapshot(&snapshot).await.unwrap();
997
998 let result = backend
999 .store_signal(
1000 "test-123",
1001 SignalKind::Cancel,
1002 SignalRequest::new(None, None),
1003 )
1004 .await;
1005 assert!(
1006 matches!(result, Err(BackendError::CannotCancel(_))),
1007 "should return CannotCancel for completed workflow"
1008 );
1009 }
1010
1011 #[tokio::test]
1012 async fn test_store_signal_cancel_failed_workflow() {
1013 let backend = InMemoryBackend::new();
1014 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1015 snapshot.mark_failed("Some error".to_string());
1016 backend.save_snapshot(&snapshot).await.unwrap();
1017
1018 let result = backend
1019 .store_signal(
1020 "test-123",
1021 SignalKind::Cancel,
1022 SignalRequest::new(None, None),
1023 )
1024 .await;
1025 assert!(
1026 matches!(result, Err(BackendError::CannotCancel(_))),
1027 "should return CannotCancel for failed workflow"
1028 );
1029 }
1030
1031 #[tokio::test]
1032 async fn test_store_signal_cancel_already_cancelled_idempotent() {
1033 let backend = InMemoryBackend::new();
1034 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1035 snapshot.mark_cancelled(Some("First cancel".to_string()), None, None);
1036 backend.save_snapshot(&snapshot).await.unwrap();
1037
1038 let result = backend
1039 .store_signal(
1040 "test-123",
1041 SignalKind::Cancel,
1042 SignalRequest::new(Some("Second cancel".to_string()), None),
1043 )
1044 .await;
1045 assert!(
1046 result.is_ok(),
1047 "cancelling already-cancelled workflow should be idempotent"
1048 );
1049 }
1050
1051 #[tokio::test]
1052 async fn test_get_signal_cancel_none() {
1053 let backend = InMemoryBackend::new();
1054
1055 let result = backend
1056 .get_signal("test-123", SignalKind::Cancel)
1057 .await
1058 .unwrap();
1059 assert!(
1060 result.is_none(),
1061 "should return None when no cancellation signal exists"
1062 );
1063 }
1064
1065 #[tokio::test]
1066 async fn test_clear_signal_cancel() {
1067 let backend = InMemoryBackend::new();
1068 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1069 backend.save_snapshot(&snapshot).await.unwrap();
1070
1071 backend
1072 .store_signal(
1073 "test-123",
1074 SignalKind::Cancel,
1075 SignalRequest::new(Some("Test".to_string()), None),
1076 )
1077 .await
1078 .unwrap();
1079
1080 assert!(
1081 backend
1082 .get_signal("test-123", SignalKind::Cancel)
1083 .await
1084 .unwrap()
1085 .is_some(),
1086 "cancel signal should exist before clearing"
1087 );
1088
1089 backend
1090 .clear_signal("test-123", SignalKind::Cancel)
1091 .await
1092 .unwrap();
1093
1094 assert!(
1095 backend
1096 .get_signal("test-123", SignalKind::Cancel)
1097 .await
1098 .unwrap()
1099 .is_none(),
1100 "cancel signal should be gone after clearing"
1101 );
1102 }
1103
1104 #[tokio::test]
1105 async fn test_store_signal_pause_completed_workflow() {
1106 let backend = InMemoryBackend::new();
1107 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1108 snapshot.mark_completed(bytes::Bytes::from("result"));
1109 backend.save_snapshot(&snapshot).await.unwrap();
1110
1111 let result = backend
1112 .store_signal(
1113 "test-123",
1114 SignalKind::Pause,
1115 SignalRequest::new(None, None),
1116 )
1117 .await;
1118 assert!(
1119 matches!(result, Err(BackendError::CannotPause(_))),
1120 "should return CannotPause for completed workflow"
1121 );
1122 }
1123
1124 #[tokio::test]
1125 async fn test_store_signal_pause_failed_workflow() {
1126 let backend = InMemoryBackend::new();
1127 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1128 snapshot.mark_failed("Some error".to_string());
1129 backend.save_snapshot(&snapshot).await.unwrap();
1130
1131 let result = backend
1132 .store_signal(
1133 "test-123",
1134 SignalKind::Pause,
1135 SignalRequest::new(None, None),
1136 )
1137 .await;
1138 assert!(
1139 matches!(result, Err(BackendError::CannotPause(_))),
1140 "should return CannotPause for failed workflow"
1141 );
1142 }
1143
1144 #[tokio::test]
1145 async fn test_store_signal_pause_cancelled_workflow() {
1146 let backend = InMemoryBackend::new();
1147 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1148 snapshot.mark_cancelled(Some("done".to_string()), None, None);
1149 backend.save_snapshot(&snapshot).await.unwrap();
1150
1151 let result = backend
1152 .store_signal(
1153 "test-123",
1154 SignalKind::Pause,
1155 SignalRequest::new(None, None),
1156 )
1157 .await;
1158 assert!(
1159 matches!(result, Err(BackendError::CannotPause(_))),
1160 "should return CannotPause for cancelled workflow"
1161 );
1162 }
1163
1164 #[tokio::test]
1165 async fn test_store_signal_pause_already_paused_idempotent() {
1166 let backend = InMemoryBackend::new();
1167 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1168 snapshot.mark_paused(&PauseRequest::new(Some("first".to_string()), None));
1169 backend.save_snapshot(&snapshot).await.unwrap();
1170
1171 let result = backend
1172 .store_signal(
1173 "test-123",
1174 SignalKind::Pause,
1175 SignalRequest::new(Some("second".to_string()), None),
1176 )
1177 .await;
1178 assert!(
1179 result.is_ok(),
1180 "pausing already-paused workflow should be idempotent"
1181 );
1182 }
1183
1184 #[tokio::test]
1185 async fn test_store_signal_pause_not_found() {
1186 let backend = InMemoryBackend::new();
1187 let result = backend
1188 .store_signal(
1189 "nonexistent",
1190 SignalKind::Pause,
1191 SignalRequest::new(None, None),
1192 )
1193 .await;
1194 assert!(
1195 matches!(result, Err(BackendError::NotFound(_))),
1196 "should return NotFound for non-existent workflow"
1197 );
1198 }
1199
1200 #[tokio::test]
1201 async fn test_check_and_cancel_success() {
1202 let backend = InMemoryBackend::new();
1203 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1204 backend.save_snapshot(&snapshot).await.unwrap();
1205
1206 backend
1207 .store_signal(
1208 "test-123",
1209 SignalKind::Cancel,
1210 SignalRequest::new(Some("Timeout".to_string()), Some("system".to_string())),
1211 )
1212 .await
1213 .unwrap();
1214
1215 let result = backend
1216 .check_and_cancel("test-123", Some("task-1"))
1217 .await
1218 .unwrap();
1219 assert!(
1220 result,
1221 "check_and_cancel should return true when cancellation pending"
1222 );
1223
1224 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1225 assert!(
1226 snapshot.state.is_cancelled(),
1227 "workflow should be in cancelled state"
1228 );
1229
1230 let WorkflowSnapshotState::Cancelled {
1231 reason,
1232 cancelled_by,
1233 interrupted_at_task,
1234 ..
1235 } = &snapshot.state
1236 else {
1237 panic!("Expected Cancelled state");
1238 };
1239 assert_eq!(reason, &Some("Timeout".to_string()));
1240 assert_eq!(cancelled_by, &Some("system".to_string()));
1241 assert_eq!(interrupted_at_task, &Some("task-1".to_string()));
1242
1243 assert!(
1244 backend
1245 .get_signal("test-123", SignalKind::Cancel)
1246 .await
1247 .unwrap()
1248 .is_none(),
1249 "cancel signal should be cleared after check_and_cancel"
1250 );
1251 }
1252
1253 #[tokio::test]
1254 async fn test_check_and_cancel_no_request() {
1255 let backend = InMemoryBackend::new();
1256 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1257 backend.save_snapshot(&snapshot).await.unwrap();
1258
1259 let result = backend.check_and_cancel("test-123", None).await.unwrap();
1260 assert!(
1261 !result,
1262 "check_and_cancel should return false when no cancellation pending"
1263 );
1264
1265 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1266 assert!(
1267 snapshot.state.is_in_progress(),
1268 "workflow should still be in progress"
1269 );
1270 }
1271
1272 #[tokio::test]
1273 async fn test_check_and_cancel_not_in_progress() {
1274 let backend = InMemoryBackend::new();
1275 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1276 snapshot.mark_completed(bytes::Bytes::from("done"));
1277 backend.save_snapshot(&snapshot).await.unwrap();
1278
1279 {
1281 let mut signals = backend.signals.write().unwrap();
1282 signals
1283 .entry("test-123".to_string())
1284 .or_default()
1285 .insert(SignalKind::Cancel, SignalRequest::new(None, None));
1286 }
1287
1288 let result = backend.check_and_cancel("test-123", None).await.unwrap();
1289 assert!(
1290 !result,
1291 "check_and_cancel should return false for non-in-progress workflow"
1292 );
1293
1294 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1295 assert!(
1296 snapshot.state.is_completed(),
1297 "workflow should still be completed"
1298 );
1299 }
1300
1301 #[tokio::test]
1302 async fn test_find_available_tasks_skips_cancelled_workflows() {
1303 let backend = InMemoryBackend::new();
1304
1305 let mut snapshot1 = WorkflowSnapshot::new("workflow-1".to_string(), "hash-abc".to_string());
1306 snapshot1.update_position(ExecutionPosition::AtTask {
1307 task_id: "task-1".to_string(),
1308 });
1309 backend.save_snapshot(&snapshot1).await.unwrap();
1310
1311 let mut snapshot2 = WorkflowSnapshot::new("workflow-2".to_string(), "hash-abc".to_string());
1312 snapshot2.update_position(ExecutionPosition::AtTask {
1313 task_id: "task-2".to_string(),
1314 });
1315 backend.save_snapshot(&snapshot2).await.unwrap();
1316
1317 backend
1318 .store_signal(
1319 "workflow-1",
1320 SignalKind::Cancel,
1321 SignalRequest::new(None, None),
1322 )
1323 .await
1324 .unwrap();
1325
1326 let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1327
1328 assert!(
1329 !tasks.iter().any(|t| t.instance_id == "workflow-1"),
1330 "workflow with pending cancellation should be skipped"
1331 );
1332 }
1333
1334 #[tokio::test]
1339 async fn test_check_and_pause_success() {
1340 let backend = InMemoryBackend::new();
1341 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1342 backend.save_snapshot(&snapshot).await.unwrap();
1343
1344 backend
1345 .store_signal(
1346 "test-123",
1347 SignalKind::Pause,
1348 SignalRequest::new(Some("maintenance".to_string()), Some("ops".to_string())),
1349 )
1350 .await
1351 .unwrap();
1352
1353 let result = backend.check_and_pause("test-123").await.unwrap();
1354 assert!(
1355 result,
1356 "check_and_pause should return true when pause pending"
1357 );
1358
1359 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1360 assert!(snapshot.state.is_paused(), "workflow should be paused");
1361
1362 let WorkflowSnapshotState::Paused {
1363 reason, paused_by, ..
1364 } = &snapshot.state
1365 else {
1366 panic!("Expected Paused state");
1367 };
1368 assert_eq!(reason, &Some("maintenance".to_string()));
1369 assert_eq!(paused_by, &Some("ops".to_string()));
1370
1371 assert!(
1372 backend
1373 .get_signal("test-123", SignalKind::Pause)
1374 .await
1375 .unwrap()
1376 .is_none(),
1377 "pause signal should be cleared after check_and_pause"
1378 );
1379 }
1380
1381 #[tokio::test]
1382 async fn test_check_and_pause_no_request() {
1383 let backend = InMemoryBackend::new();
1384 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1385 backend.save_snapshot(&snapshot).await.unwrap();
1386
1387 let result = backend.check_and_pause("test-123").await.unwrap();
1388 assert!(
1389 !result,
1390 "check_and_pause should return false when no pause pending"
1391 );
1392
1393 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1394 assert!(
1395 snapshot.state.is_in_progress(),
1396 "workflow should still be in progress"
1397 );
1398 }
1399
1400 #[tokio::test]
1401 async fn test_check_and_pause_not_in_progress() {
1402 let backend = InMemoryBackend::new();
1403 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1404 snapshot.mark_completed(bytes::Bytes::from("done"));
1405 backend.save_snapshot(&snapshot).await.unwrap();
1406
1407 {
1409 let mut signals = backend.signals.write().unwrap();
1410 signals
1411 .entry("test-123".to_string())
1412 .or_default()
1413 .insert(SignalKind::Pause, SignalRequest::new(None, None));
1414 }
1415
1416 let result = backend.check_and_pause("test-123").await.unwrap();
1417 assert!(
1418 !result,
1419 "check_and_pause should return false for non-in-progress workflow"
1420 );
1421
1422 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1423 assert!(
1424 snapshot.state.is_completed(),
1425 "workflow should still be completed"
1426 );
1427 }
1428
1429 #[tokio::test]
1430 async fn test_check_and_pause_preserves_position() {
1431 let backend = InMemoryBackend::new();
1432 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1433 snapshot.update_position(ExecutionPosition::AtTask {
1434 task_id: "task-3".to_string(),
1435 });
1436 snapshot.mark_task_completed("task-1".to_string(), bytes::Bytes::from("out1"));
1437 snapshot.mark_task_completed("task-2".to_string(), bytes::Bytes::from("out2"));
1438 backend.save_snapshot(&snapshot).await.unwrap();
1439
1440 backend
1441 .store_signal(
1442 "test-123",
1443 SignalKind::Pause,
1444 SignalRequest::new(None, None),
1445 )
1446 .await
1447 .unwrap();
1448
1449 backend.check_and_pause("test-123").await.unwrap();
1450
1451 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1452 let WorkflowSnapshotState::Paused {
1453 completed_tasks,
1454 position,
1455 last_completed_task_id,
1456 ..
1457 } = &snapshot.state
1458 else {
1459 panic!("Expected Paused state");
1460 };
1461
1462 assert_eq!(completed_tasks.len(), 2);
1463 assert!(completed_tasks.contains_key("task-1"));
1464 assert!(completed_tasks.contains_key("task-2"));
1465 assert!(matches!(
1466 position,
1467 ExecutionPosition::AtTask { task_id } if task_id == "task-3"
1468 ));
1469 assert_eq!(last_completed_task_id, &Some("task-2".to_string()));
1470 }
1471
1472 #[tokio::test]
1477 async fn test_unpause_success() {
1478 let backend = InMemoryBackend::new();
1479 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1480 snapshot.update_position(ExecutionPosition::AtTask {
1481 task_id: "task-2".to_string(),
1482 });
1483 snapshot.mark_task_completed("task-1".to_string(), bytes::Bytes::from("out1"));
1484 snapshot.mark_paused(&PauseRequest::new(
1485 Some("maintenance".to_string()),
1486 Some("ops".to_string()),
1487 ));
1488 backend.save_snapshot(&snapshot).await.unwrap();
1489
1490 let result = backend.unpause("test-123").await.unwrap();
1491
1492 assert!(
1493 result.state.is_in_progress(),
1494 "unpaused workflow should be in progress"
1495 );
1496
1497 let WorkflowSnapshotState::InProgress {
1499 position,
1500 completed_tasks,
1501 last_completed_task_id,
1502 } = &result.state
1503 else {
1504 panic!("Expected InProgress state");
1505 };
1506 assert!(matches!(
1507 position,
1508 ExecutionPosition::AtTask { task_id } if task_id == "task-2"
1509 ));
1510 assert!(completed_tasks.contains_key("task-1"));
1511 assert_eq!(last_completed_task_id, &Some("task-1".to_string()));
1512
1513 let loaded = backend.load_snapshot("test-123").await.unwrap();
1515 assert!(loaded.state.is_in_progress());
1516 }
1517
1518 #[tokio::test]
1519 async fn test_unpause_not_paused_errors() {
1520 let backend = InMemoryBackend::new();
1521 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1522 backend.save_snapshot(&snapshot).await.unwrap();
1523
1524 let result = backend.unpause("test-123").await;
1525 assert!(
1526 matches!(result, Err(BackendError::CannotPause(_))),
1527 "unpause on in-progress workflow should error"
1528 );
1529 }
1530
1531 #[tokio::test]
1532 async fn test_unpause_completed_errors() {
1533 let backend = InMemoryBackend::new();
1534 let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1535 snapshot.mark_completed(bytes::Bytes::from("done"));
1536 backend.save_snapshot(&snapshot).await.unwrap();
1537
1538 let result = backend.unpause("test-123").await;
1539 assert!(
1540 matches!(result, Err(BackendError::CannotPause(_))),
1541 "unpause on completed workflow should error"
1542 );
1543 }
1544
1545 #[tokio::test]
1546 async fn test_unpause_not_found() {
1547 let backend = InMemoryBackend::new();
1548 let result = backend.unpause("nonexistent").await;
1549 assert!(matches!(result, Err(BackendError::NotFound(_))));
1550 }
1551
1552 #[tokio::test]
1557 async fn test_cancel_and_pause_simultaneously_cancel_wins() {
1558 let backend = InMemoryBackend::new();
1559 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1560 backend.save_snapshot(&snapshot).await.unwrap();
1561
1562 backend
1564 .store_signal(
1565 "test-123",
1566 SignalKind::Cancel,
1567 SignalRequest::new(Some("cancel reason".to_string()), None),
1568 )
1569 .await
1570 .unwrap();
1571 backend
1572 .store_signal(
1573 "test-123",
1574 SignalKind::Pause,
1575 SignalRequest::new(Some("pause reason".to_string()), None),
1576 )
1577 .await
1578 .unwrap();
1579
1580 let cancelled = backend
1582 .check_and_cancel("test-123", Some("task-1"))
1583 .await
1584 .unwrap();
1585 assert!(cancelled, "cancel should succeed");
1586
1587 let paused = backend.check_and_pause("test-123").await.unwrap();
1589 assert!(
1590 !paused,
1591 "pause should return false since workflow is already cancelled"
1592 );
1593
1594 let snapshot = backend.load_snapshot("test-123").await.unwrap();
1595 assert!(snapshot.state.is_cancelled());
1596 }
1597
1598 #[tokio::test]
1599 async fn test_cancel_signal_independent_of_pause_signal() {
1600 let backend = InMemoryBackend::new();
1601 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1602 backend.save_snapshot(&snapshot).await.unwrap();
1603
1604 backend
1606 .store_signal(
1607 "test-123",
1608 SignalKind::Cancel,
1609 SignalRequest::new(Some("cancel".to_string()), None),
1610 )
1611 .await
1612 .unwrap();
1613 backend
1614 .store_signal(
1615 "test-123",
1616 SignalKind::Pause,
1617 SignalRequest::new(Some("pause".to_string()), None),
1618 )
1619 .await
1620 .unwrap();
1621
1622 backend
1624 .clear_signal("test-123", SignalKind::Cancel)
1625 .await
1626 .unwrap();
1627
1628 assert!(
1630 backend
1631 .get_signal("test-123", SignalKind::Cancel)
1632 .await
1633 .unwrap()
1634 .is_none()
1635 );
1636 assert!(
1637 backend
1638 .get_signal("test-123", SignalKind::Pause)
1639 .await
1640 .unwrap()
1641 .is_some()
1642 );
1643 }
1644
1645 #[tokio::test]
1650 async fn test_find_available_tasks_skips_paused_workflows() {
1651 let backend = InMemoryBackend::new();
1652
1653 let mut snapshot1 = WorkflowSnapshot::with_initial_input(
1654 "workflow-1".to_string(),
1655 "hash-abc".to_string(),
1656 bytes::Bytes::from(vec![1]),
1657 );
1658 snapshot1.update_position(ExecutionPosition::AtTask {
1659 task_id: "task-1".to_string(),
1660 });
1661 backend.save_snapshot(&snapshot1).await.unwrap();
1662
1663 let mut snapshot2 = WorkflowSnapshot::with_initial_input(
1664 "workflow-2".to_string(),
1665 "hash-abc".to_string(),
1666 bytes::Bytes::from(vec![2]),
1667 );
1668 snapshot2.update_position(ExecutionPosition::AtTask {
1669 task_id: "task-2".to_string(),
1670 });
1671 backend.save_snapshot(&snapshot2).await.unwrap();
1672
1673 backend
1675 .store_signal(
1676 "workflow-1",
1677 SignalKind::Pause,
1678 SignalRequest::new(None, None),
1679 )
1680 .await
1681 .unwrap();
1682
1683 let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1684
1685 assert!(
1686 !tasks.iter().any(|t| t.instance_id == "workflow-1"),
1687 "workflow with pending pause should be skipped"
1688 );
1689 assert!(
1690 tasks.iter().any(|t| t.instance_id == "workflow-2"),
1691 "workflow without signals should be available"
1692 );
1693 }
1694
1695 #[tokio::test]
1700 async fn test_delete_snapshot_leaves_orphaned_signals() {
1701 let backend = InMemoryBackend::new();
1702 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1703 backend.save_snapshot(&snapshot).await.unwrap();
1704
1705 backend
1706 .store_signal(
1707 "test-123",
1708 SignalKind::Cancel,
1709 SignalRequest::new(Some("reason".to_string()), None),
1710 )
1711 .await
1712 .unwrap();
1713
1714 backend.delete_snapshot("test-123").await.unwrap();
1716
1717 let signal = backend
1719 .get_signal("test-123", SignalKind::Cancel)
1720 .await
1721 .unwrap();
1722 assert!(
1723 signal.is_some(),
1724 "signal persists after snapshot deletion (orphaned)"
1725 );
1726 }
1727
1728 #[tokio::test]
1729 async fn test_store_signal_overwrites_previous() {
1730 let backend = InMemoryBackend::new();
1731 let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1732 backend.save_snapshot(&snapshot).await.unwrap();
1733
1734 backend
1735 .store_signal(
1736 "test-123",
1737 SignalKind::Cancel,
1738 SignalRequest::new(Some("first".to_string()), None),
1739 )
1740 .await
1741 .unwrap();
1742 backend
1743 .store_signal(
1744 "test-123",
1745 SignalKind::Cancel,
1746 SignalRequest::new(Some("second".to_string()), None),
1747 )
1748 .await
1749 .unwrap();
1750
1751 let signal = backend
1752 .get_signal("test-123", SignalKind::Cancel)
1753 .await
1754 .unwrap()
1755 .unwrap();
1756 assert_eq!(
1757 signal.reason,
1758 Some("second".to_string()),
1759 "latest signal should overwrite previous"
1760 );
1761 }
1762
1763 #[tokio::test]
1768 async fn test_find_available_tasks_skips_unexpired_delay() {
1769 let backend = InMemoryBackend::new();
1770
1771 let mut snapshot = WorkflowSnapshot::with_initial_input(
1772 "workflow-1".to_string(),
1773 "hash-abc".to_string(),
1774 bytes::Bytes::from(vec![42]),
1775 );
1776 let wake_at = Utc::now() + chrono::Duration::hours(1);
1778 snapshot.update_position(ExecutionPosition::AtDelay {
1779 delay_id: "wait_1h".to_string(),
1780 entered_at: Utc::now(),
1781 wake_at,
1782 next_task_id: Some("next_step".to_string()),
1783 });
1784 snapshot.mark_task_completed("wait_1h".to_string(), bytes::Bytes::from(vec![42]));
1785 backend.save_snapshot(&snapshot).await.unwrap();
1786
1787 let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1788 assert!(
1789 tasks.is_empty(),
1790 "workflow at unexpired delay should not appear in available tasks"
1791 );
1792 }
1793
1794 #[tokio::test]
1795 async fn test_find_available_tasks_advances_expired_delay() {
1796 let backend = InMemoryBackend::new();
1797
1798 let mut snapshot = WorkflowSnapshot::with_initial_input(
1799 "workflow-1".to_string(),
1800 "hash-abc".to_string(),
1801 bytes::Bytes::from(vec![42]),
1802 );
1803 let wake_at = Utc::now() - chrono::Duration::seconds(1);
1805 snapshot.update_position(ExecutionPosition::AtDelay {
1806 delay_id: "wait_done".to_string(),
1807 entered_at: Utc::now() - chrono::Duration::seconds(2),
1808 wake_at,
1809 next_task_id: Some("process".to_string()),
1810 });
1811 snapshot.mark_task_completed("wait_done".to_string(), bytes::Bytes::from(vec![42]));
1812 backend.save_snapshot(&snapshot).await.unwrap();
1813
1814 let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1815
1816 assert_eq!(tasks.len(), 1);
1818 assert_eq!(tasks[0].instance_id, "workflow-1");
1819 assert_eq!(tasks[0].task_id, "process");
1820
1821 let loaded = backend.load_snapshot("workflow-1").await.unwrap();
1823 match &loaded.state {
1824 WorkflowSnapshotState::InProgress {
1825 position: ExecutionPosition::AtTask { task_id },
1826 ..
1827 } => {
1828 assert_eq!(task_id, "process");
1829 }
1830 other => panic!("Expected AtTask position, got {other:?}"),
1831 }
1832 }
1833
1834 #[tokio::test]
1835 async fn test_find_available_tasks_completes_expired_delay_last_node() {
1836 let backend = InMemoryBackend::new();
1837
1838 let mut snapshot = WorkflowSnapshot::with_initial_input(
1839 "workflow-1".to_string(),
1840 "hash-abc".to_string(),
1841 bytes::Bytes::from(vec![42]),
1842 );
1843 let wake_at = Utc::now() - chrono::Duration::seconds(1);
1845 snapshot.update_position(ExecutionPosition::AtDelay {
1846 delay_id: "final_wait".to_string(),
1847 entered_at: Utc::now() - chrono::Duration::seconds(2),
1848 wake_at,
1849 next_task_id: None,
1850 });
1851 snapshot.mark_task_completed("final_wait".to_string(), bytes::Bytes::from(vec![42]));
1852 backend.save_snapshot(&snapshot).await.unwrap();
1853
1854 let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1855
1856 assert!(
1858 tasks.is_empty(),
1859 "completed workflow should not appear in available tasks"
1860 );
1861
1862 let loaded = backend.load_snapshot("workflow-1").await.unwrap();
1864 assert!(
1865 loaded.state.is_completed(),
1866 "workflow should be completed when delay is last node and expired"
1867 );
1868 }
1869
1870 #[tokio::test]
1873 async fn test_send_and_consume_event_fifo() {
1874 let backend = InMemoryBackend::new();
1875
1876 backend
1878 .send_event("wf-1", "approval", bytes::Bytes::from("first"))
1879 .await
1880 .unwrap();
1881 backend
1882 .send_event("wf-1", "approval", bytes::Bytes::from("second"))
1883 .await
1884 .unwrap();
1885
1886 let first = backend.consume_event("wf-1", "approval").await.unwrap();
1888 assert_eq!(first.as_deref(), Some(b"first".as_slice()));
1889
1890 let second = backend.consume_event("wf-1", "approval").await.unwrap();
1891 assert_eq!(second.as_deref(), Some(b"second".as_slice()));
1892
1893 let none = backend.consume_event("wf-1", "approval").await.unwrap();
1895 assert!(none.is_none());
1896 }
1897
1898 #[tokio::test]
1899 async fn test_consume_event_empty_returns_none() {
1900 let backend = InMemoryBackend::new();
1901 let result = backend.consume_event("wf-1", "nonexistent").await.unwrap();
1902 assert!(result.is_none());
1903 }
1904
1905 #[tokio::test]
1906 async fn test_events_are_isolated_by_signal_name() {
1907 let backend = InMemoryBackend::new();
1908
1909 backend
1910 .send_event("wf-1", "sig_a", bytes::Bytes::from("a_payload"))
1911 .await
1912 .unwrap();
1913 backend
1914 .send_event("wf-1", "sig_b", bytes::Bytes::from("b_payload"))
1915 .await
1916 .unwrap();
1917
1918 let a = backend.consume_event("wf-1", "sig_a").await.unwrap();
1920 assert_eq!(a.as_deref(), Some(b"a_payload".as_slice()));
1921
1922 let b = backend.consume_event("wf-1", "sig_b").await.unwrap();
1923 assert_eq!(b.as_deref(), Some(b"b_payload".as_slice()));
1924 }
1925}