Skip to main content

sayiir_persistence/
in_memory.rs

//! In-memory implementation of the persistence traits.
//!
//! This is a simple implementation that keeps snapshots, task claims, signals and
//! buffered events in `HashMap`s guarded by `RwLock`s.
//! Useful for testing and as a reference implementation.
5
6use crate::backend::{BackendError, SignalStore, SnapshotStore, TaskClaimStore};
7use chrono::{Duration, Utc};
8use sayiir_core::snapshot::{
9    ExecutionPosition, PauseRequest, SignalKind, SignalRequest, WorkflowSnapshot,
10    WorkflowSnapshotState,
11};
12use sayiir_core::task_claim::{AvailableTask, TaskClaim};
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, RwLock};
15
/// In-memory backend that stores snapshots in a HashMap.
///
/// This implementation is thread-safe and suitable for testing.
/// For production use, consider implementing the persistence traits for
/// a more durable storage backend (Redis, PostgreSQL, etc.).
///
/// Every field is an `Arc<RwLock<..>>`, so the derived `Clone` produces a handle
/// that shares the same underlying maps rather than copying their contents.
#[derive(Clone, Default)]
pub struct InMemoryBackend {
    /// Workflow snapshots, keyed by instance id.
    snapshots: Arc<RwLock<HashMap<String, WorkflowSnapshot>>>,
    /// Active task claims; see `claim_key` for how the key is built.
    claims: Arc<RwLock<HashMap<String, TaskClaim>>>, // Key: "{instance_id}:{task_id}"
    /// Pending cancel/pause requests: instance id -> (signal kind -> request).
    /// At most one request per kind is kept for an instance.
    signals: Arc<RwLock<HashMap<String, HashMap<SignalKind, SignalRequest>>>>,
    /// Buffered external events per `(instance_id, signal_name)`, FIFO order.
    #[allow(clippy::type_complexity)]
    events: Arc<RwLock<HashMap<(String, String), VecDeque<bytes::Bytes>>>>,
}
30
31impl InMemoryBackend {
32    /// Create a new in-memory backend.
33    pub fn new() -> Self {
34        Default::default()
35    }
36
37    fn claim_key(instance_id: &str, task_id: &str) -> String {
38        format!("{}:{}", instance_id, task_id)
39    }
40
41    /// Convert a lock error into a BackendError.
42    fn lock_error<E: std::fmt::Display>(e: E) -> BackendError {
43        BackendError::Backend(format!("Lock error: {e}"))
44    }
45}
46
47// ---------------------------------------------------------------------------
48// SnapshotStore
49// ---------------------------------------------------------------------------
50
51impl SnapshotStore for InMemoryBackend {
52    async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
53        let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
54        snapshots.insert(snapshot.instance_id.clone(), snapshot.clone());
55        Ok(())
56    }
57
58    async fn save_task_result(
59        &self,
60        instance_id: &str,
61        task_id: &str,
62        output: bytes::Bytes,
63    ) -> Result<(), BackendError> {
64        let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
65
66        let snapshot = snapshots
67            .get_mut(instance_id)
68            .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
69
70        snapshot.mark_task_completed(task_id.to_string(), output);
71        Ok(())
72    }
73
74    async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
75        let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
76        snapshots
77            .get(instance_id)
78            .cloned()
79            .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))
80    }
81
82    async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
83        let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
84        snapshots
85            .remove(instance_id)
86            .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))
87            .map(|_| ())
88    }
89
90    async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
91        let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
92        Ok(snapshots.keys().cloned().collect())
93    }
94}
95
96// ---------------------------------------------------------------------------
97// SignalStore (overrides default composites for lock efficiency)
98// ---------------------------------------------------------------------------
99
100impl SignalStore for InMemoryBackend {
101    async fn store_signal(
102        &self,
103        instance_id: &str,
104        kind: SignalKind,
105        request: SignalRequest,
106    ) -> Result<(), BackendError> {
107        // Validate that the workflow exists and is in a signalable state
108        {
109            let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
110            let snapshot = snapshots
111                .get(instance_id)
112                .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
113
114            match kind {
115                SignalKind::Cancel => {
116                    if snapshot.state.is_completed() {
117                        return Err(BackendError::CannotCancel("Completed".to_string()));
118                    }
119                    if snapshot.state.is_failed() {
120                        return Err(BackendError::CannotCancel("Failed".to_string()));
121                    }
122                    if snapshot.state.is_cancelled() {
123                        return Ok(()); // idempotent
124                    }
125                }
126                SignalKind::Pause => {
127                    if snapshot.state.is_completed() {
128                        return Err(BackendError::CannotPause("Completed".to_string()));
129                    }
130                    if snapshot.state.is_failed() {
131                        return Err(BackendError::CannotPause("Failed".to_string()));
132                    }
133                    if snapshot.state.is_cancelled() {
134                        return Err(BackendError::CannotPause("Cancelled".to_string()));
135                    }
136                    if snapshot.state.is_paused() {
137                        return Ok(()); // idempotent
138                    }
139                }
140            }
141        }
142
143        let mut signals = self.signals.write().map_err(Self::lock_error)?;
144        signals
145            .entry(instance_id.to_string())
146            .or_default()
147            .insert(kind, request);
148        Ok(())
149    }
150
151    async fn get_signal(
152        &self,
153        instance_id: &str,
154        kind: SignalKind,
155    ) -> Result<Option<SignalRequest>, BackendError> {
156        let signals = self.signals.read().map_err(Self::lock_error)?;
157        Ok(signals.get(instance_id).and_then(|m| m.get(&kind)).cloned())
158    }
159
160    async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
161        let mut signals = self.signals.write().map_err(Self::lock_error)?;
162        if let Some(inner) = signals.get_mut(instance_id) {
163            inner.remove(&kind);
164            if inner.is_empty() {
165                signals.remove(instance_id);
166            }
167        }
168        Ok(())
169    }
170
171    async fn send_event(
172        &self,
173        instance_id: &str,
174        signal_name: &str,
175        payload: bytes::Bytes,
176    ) -> Result<(), BackendError> {
177        let mut events = self.events.write().map_err(Self::lock_error)?;
178        events
179            .entry((instance_id.to_string(), signal_name.to_string()))
180            .or_default()
181            .push_back(payload);
182        Ok(())
183    }
184
185    async fn consume_event(
186        &self,
187        instance_id: &str,
188        signal_name: &str,
189    ) -> Result<Option<bytes::Bytes>, BackendError> {
190        let mut events = self.events.write().map_err(Self::lock_error)?;
191        let key = (instance_id.to_string(), signal_name.to_string());
192        let payload = events.get_mut(&key).and_then(VecDeque::pop_front);
193        // Clean up empty queues
194        if events.get(&key).is_some_and(VecDeque::is_empty) {
195            events.remove(&key);
196        }
197        Ok(payload)
198    }
199
200    // Override check_and_cancel for more efficient locking (avoids load+save round-trip).
201    async fn check_and_cancel(
202        &self,
203        instance_id: &str,
204        interrupted_at_task: Option<&str>,
205    ) -> Result<bool, BackendError> {
206        let request = {
207            let signals = self.signals.read().map_err(Self::lock_error)?;
208            match signals
209                .get(instance_id)
210                .and_then(|m| m.get(&SignalKind::Cancel))
211            {
212                Some(req) => req.clone(),
213                None => return Ok(false),
214            }
215        };
216
217        {
218            let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
219            let snapshot = snapshots
220                .get_mut(instance_id)
221                .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
222            if !snapshot.state.is_in_progress() {
223                return Ok(false);
224            }
225            snapshot.mark_cancelled(
226                request.reason,
227                request.requested_by,
228                interrupted_at_task.map(String::from),
229            );
230        }
231
232        {
233            let mut signals = self.signals.write().map_err(Self::lock_error)?;
234            if let Some(inner) = signals.get_mut(instance_id) {
235                inner.remove(&SignalKind::Cancel);
236                if inner.is_empty() {
237                    signals.remove(instance_id);
238                }
239            }
240        }
241
242        Ok(true)
243    }
244
245    // Override check_and_pause for more efficient locking.
246    async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
247        let request = {
248            let signals = self.signals.read().map_err(Self::lock_error)?;
249            match signals
250                .get(instance_id)
251                .and_then(|m| m.get(&SignalKind::Pause))
252            {
253                Some(req) => req.clone(),
254                None => return Ok(false),
255            }
256        };
257
258        {
259            let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
260            let snapshot = snapshots
261                .get_mut(instance_id)
262                .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
263            if !snapshot.state.is_in_progress() {
264                return Ok(false);
265            }
266            let pause_request: PauseRequest = request.into();
267            snapshot.mark_paused(&pause_request);
268        }
269
270        {
271            let mut signals = self.signals.write().map_err(Self::lock_error)?;
272            if let Some(inner) = signals.get_mut(instance_id) {
273                inner.remove(&SignalKind::Pause);
274                if inner.is_empty() {
275                    signals.remove(instance_id);
276                }
277            }
278        }
279
280        Ok(true)
281    }
282
283    // Override unpause for more efficient locking.
284    async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
285        let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
286
287        let snapshot = snapshots
288            .get_mut(instance_id)
289            .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
290
291        if !snapshot.state.is_paused() {
292            return Err(BackendError::CannotPause(format!(
293                "Workflow is not paused (current state: {:?})",
294                if snapshot.state.is_in_progress() {
295                    "InProgress"
296                } else if snapshot.state.is_completed() {
297                    "Completed"
298                } else if snapshot.state.is_failed() {
299                    "Failed"
300                } else if snapshot.state.is_cancelled() {
301                    "Cancelled"
302                } else {
303                    "Unknown"
304                }
305            )));
306        }
307
308        snapshot.mark_unpaused();
309        Ok(snapshot.clone())
310    }
311}
312
313// ---------------------------------------------------------------------------
314// TaskClaimStore
315// ---------------------------------------------------------------------------
316
317impl TaskClaimStore for InMemoryBackend {
318    async fn claim_task(
319        &self,
320        instance_id: &str,
321        task_id: &str,
322        worker_id: &str,
323        ttl: Option<Duration>,
324    ) -> Result<Option<TaskClaim>, BackendError> {
325        let key = Self::claim_key(instance_id, task_id);
326        let mut claims = self.claims.write().map_err(Self::lock_error)?;
327
328        // Check if already claimed and not expired
329        if let Some(existing_claim) = claims.get(&key) {
330            if !existing_claim.is_expired() {
331                return Ok(None); // Already claimed
332            }
333            // Expired claim, remove it
334            claims.remove(&key);
335        }
336
337        // Create new claim
338        let claim = TaskClaim::new(
339            instance_id.to_string(),
340            task_id.to_string(),
341            worker_id.to_string(),
342            ttl,
343        );
344        claims.insert(key, claim.clone());
345        Ok(Some(claim))
346    }
347
348    async fn release_task_claim(
349        &self,
350        instance_id: &str,
351        task_id: &str,
352        worker_id: &str,
353    ) -> Result<(), BackendError> {
354        let key = Self::claim_key(instance_id, task_id);
355        let mut claims = self.claims.write().map_err(Self::lock_error)?;
356
357        if let Some(claim) = claims.get(&key) {
358            if claim.worker_id != worker_id {
359                return Err(BackendError::Backend(format!(
360                    "Claim owned by different worker: {}",
361                    claim.worker_id
362                )));
363            }
364            claims.remove(&key);
365            Ok(())
366        } else {
367            Err(BackendError::NotFound(format!(
368                "{}:{}",
369                instance_id, task_id
370            )))
371        }
372    }
373
374    async fn extend_task_claim(
375        &self,
376        instance_id: &str,
377        task_id: &str,
378        worker_id: &str,
379        additional_duration: Duration,
380    ) -> Result<(), BackendError> {
381        let key = Self::claim_key(instance_id, task_id);
382        let mut claims = self.claims.write().map_err(Self::lock_error)?;
383
384        if let Some(claim) = claims.get_mut(&key) {
385            if claim.worker_id != worker_id {
386                return Err(BackendError::Backend(format!(
387                    "Claim owned by different worker: {}",
388                    claim.worker_id
389                )));
390            }
391
392            if let Some(expires_at) = claim.expires_at {
393                let expires_datetime = chrono::DateTime::from_timestamp(expires_at as i64, 0)
394                    .ok_or_else(|| BackendError::Backend("Invalid timestamp".to_string()))?;
395                let new_expiry = expires_datetime
396                    .checked_add_signed(additional_duration)
397                    .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
398                claim.expires_at = Some(new_expiry.timestamp() as u64);
399            }
400            Ok(())
401        } else {
402            Err(BackendError::NotFound(format!(
403                "{}:{}",
404                instance_id, task_id
405            )))
406        }
407    }
408
409    async fn find_available_tasks(
410        &self,
411        worker_id: &str,
412        limit: usize,
413    ) -> Result<Vec<AvailableTask>, BackendError> {
414        // Clean up expired claims first
415        {
416            let mut claims = self.claims.write().map_err(Self::lock_error)?;
417            claims.retain(|_, claim| !claim.is_expired());
418        }
419
420        // Collect delay-expired workflows that need position advancement
421        let mut delay_advances: Vec<(String, String)> = Vec::new();
422        let mut delay_completions: Vec<(String, String)> = Vec::new();
423        // Signal-related advancements: (instance_id, signal_name, next_task_id_or_none)
424        let mut signal_advances: Vec<(String, String, Option<String>)> = Vec::new();
425        // Signal timeout expirations: (instance_id, signal_id, next_task_id_or_none)
426        let mut signal_timeout_advances: Vec<(String, String, Option<String>)> = Vec::new();
427
428        {
429            let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
430            let signals = self.signals.read().map_err(Self::lock_error)?;
431            let events = self.events.read().map_err(Self::lock_error)?;
432
433            for (instance_id, snapshot) in snapshots.iter() {
434                if !snapshot.state.is_in_progress() {
435                    continue;
436                }
437                if signals
438                    .get(instance_id.as_str())
439                    .is_some_and(|m| m.contains_key(&SignalKind::Cancel))
440                {
441                    continue;
442                }
443                if signals
444                    .get(instance_id.as_str())
445                    .is_some_and(|m| m.contains_key(&SignalKind::Pause))
446                {
447                    continue;
448                }
449                match &snapshot.state {
450                    WorkflowSnapshotState::InProgress {
451                        position:
452                            ExecutionPosition::AtDelay {
453                                wake_at,
454                                delay_id,
455                                next_task_id,
456                                ..
457                            },
458                        ..
459                    } if Utc::now() >= *wake_at => {
460                        if let Some(next_id) = next_task_id {
461                            delay_advances.push((instance_id.clone(), next_id.clone()));
462                        } else {
463                            delay_completions.push((instance_id.clone(), delay_id.clone()));
464                        }
465                    }
466                    WorkflowSnapshotState::InProgress {
467                        position:
468                            ExecutionPosition::AtSignal {
469                                signal_name,
470                                wake_at,
471                                next_task_id,
472                                ..
473                            },
474                        ..
475                    } => {
476                        let key = (instance_id.clone(), signal_name.clone());
477                        if events.get(&key).is_some_and(|q| !q.is_empty()) {
478                            // Signal arrived — advance
479                            signal_advances.push((
480                                instance_id.clone(),
481                                signal_name.clone(),
482                                next_task_id.clone(),
483                            ));
484                        } else if wake_at.is_some_and(|wt| Utc::now() >= wt) {
485                            // Timeout expired — advance with None payload
486                            signal_timeout_advances.push((
487                                instance_id.clone(),
488                                signal_name.clone(),
489                                next_task_id.clone(),
490                            ));
491                        }
492                    }
493                    _ => {}
494                }
495            }
496        }
497
498        // Apply delay advancements with write lock
499        if !delay_advances.is_empty()
500            || !delay_completions.is_empty()
501            || !signal_advances.is_empty()
502            || !signal_timeout_advances.is_empty()
503        {
504            let mut snapshots = self.snapshots.write().map_err(Self::lock_error)?;
505            for (instance_id, next_task_id) in &delay_advances {
506                if let Some(snapshot) = snapshots.get_mut(instance_id) {
507                    snapshot.update_position(ExecutionPosition::AtTask {
508                        task_id: next_task_id.clone(),
509                    });
510                }
511            }
512            for (instance_id, delay_id) in &delay_completions {
513                if let Some(snapshot) = snapshots.get_mut(instance_id) {
514                    let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
515                    snapshot.mark_completed(output);
516                }
517            }
518            // Consume signal events and advance position
519            {
520                let mut events = self.events.write().map_err(Self::lock_error)?;
521                for (instance_id, signal_name, next_task_id) in &signal_advances {
522                    let key = (instance_id.clone(), signal_name.clone());
523                    let payload = events
524                        .get_mut(&key)
525                        .and_then(VecDeque::pop_front)
526                        .unwrap_or_default();
527                    // Clean up empty queues
528                    if events.get(&key).is_some_and(VecDeque::is_empty) {
529                        events.remove(&key);
530                    }
531                    if let Some(snapshot) = snapshots.get_mut(instance_id) {
532                        // Store signal payload as a task result so the next step can use it
533                        snapshot.mark_task_completed(signal_name.clone(), payload);
534                        if let Some(next_id) = next_task_id {
535                            snapshot.update_position(ExecutionPosition::AtTask {
536                                task_id: next_id.clone(),
537                            });
538                        } else {
539                            let output = snapshot
540                                .get_task_result_bytes(signal_name)
541                                .unwrap_or_default();
542                            snapshot.mark_completed(output);
543                        }
544                    }
545                }
546            }
547            // Handle signal timeouts (advance with empty payload)
548            for (instance_id, signal_name, next_task_id) in &signal_timeout_advances {
549                if let Some(snapshot) = snapshots.get_mut(instance_id) {
550                    snapshot.mark_task_completed(signal_name.clone(), bytes::Bytes::new());
551                    if let Some(next_id) = next_task_id {
552                        snapshot.update_position(ExecutionPosition::AtTask {
553                            task_id: next_id.clone(),
554                        });
555                    } else {
556                        snapshot.mark_completed(bytes::Bytes::new());
557                    }
558                }
559            }
560        }
561
562        let snapshots = self.snapshots.read().map_err(Self::lock_error)?;
563        let claims = self.claims.read().map_err(Self::lock_error)?;
564        let signals = self.signals.read().map_err(Self::lock_error)?;
565
566        let mut available = Vec::new();
567
568        for (instance_id, snapshot) in snapshots.iter() {
569            if !snapshot.state.is_in_progress() {
570                continue;
571            }
572
573            // Skip workflows with pending cancellation or pause requests
574            if let Some(instance_signals) = signals.get(instance_id.as_str())
575                && (instance_signals.contains_key(&SignalKind::Cancel)
576                    || instance_signals.contains_key(&SignalKind::Pause))
577            {
578                continue;
579            }
580
581            if let WorkflowSnapshotState::InProgress {
582                completed_tasks,
583                position: ExecutionPosition::AtTask { task_id },
584                ..
585            } = &snapshot.state
586            {
587                let claim_key = Self::claim_key(instance_id, task_id);
588                let is_claimed = claims.contains_key(&claim_key);
589                let is_completed = completed_tasks.contains_key(task_id);
590
591                if !is_completed && !is_claimed {
592                    // Skip tasks whose retry backoff has not elapsed yet
593                    if let Some(rs) = snapshot.task_retries.get(task_id)
594                        && Utc::now() < rs.next_retry_at
595                    {
596                        continue;
597                    }
598
599                    let input = if completed_tasks.is_empty() {
600                        snapshot.initial_input_bytes()
601                    } else {
602                        snapshot.get_last_task_output()
603                    };
604
605                    if let Some(input_bytes) = input {
606                        available.push(AvailableTask {
607                            instance_id: instance_id.clone(),
608                            task_id: task_id.clone(),
609                            input: input_bytes,
610                            workflow_definition_hash: snapshot.definition_hash.clone(),
611                        });
612
613                        if available.len() >= limit {
614                            break;
615                        }
616                    }
617                }
618            }
619        }
620
621        // Soft worker bias: tasks that did NOT fail on this worker come first.
622        // `false < true`, so non-failed-on-this-worker tasks are sorted first.
623        available.sort_by_key(|t| {
624            snapshots
625                .get(&t.instance_id)
626                .and_then(|s| s.task_retries.get(&t.task_id))
627                .and_then(|rs| rs.last_failed_worker.as_deref())
628                .is_some_and(|w| w == worker_id)
629        });
630
631        Ok(available)
632    }
633}
634
635#[cfg(test)]
636mod tests {
637    use super::*;
638    use crate::backend::SignalStore;
639    use sayiir_core::snapshot::SignalKind;
640
641    #[tokio::test]
642    async fn test_save_and_load() {
643        let backend = InMemoryBackend::new();
644        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
645
646        backend.save_snapshot(&snapshot).await.unwrap();
647        let loaded = backend.load_snapshot("test-123").await.unwrap();
648
649        assert_eq!(snapshot.instance_id, loaded.instance_id);
650        assert_eq!(snapshot.definition_hash, loaded.definition_hash);
651    }
652
653    #[tokio::test]
654    async fn test_not_found() {
655        let backend = InMemoryBackend::new();
656        let result = backend.load_snapshot("nonexistent").await;
657        assert!(matches!(result, Err(BackendError::NotFound(_))));
658    }
659
660    #[tokio::test]
661    async fn test_delete() {
662        let backend = InMemoryBackend::new();
663        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
664
665        backend.save_snapshot(&snapshot).await.unwrap();
666        backend.delete_snapshot("test-123").await.unwrap();
667
668        let result = backend.load_snapshot("test-123").await;
669        assert!(matches!(result, Err(BackendError::NotFound(_))));
670    }
671
672    #[tokio::test]
673    async fn test_list_snapshots() {
674        let backend = InMemoryBackend::new();
675
676        backend
677            .save_snapshot(&WorkflowSnapshot::new(
678                "test-1".to_string(),
679                "hash-1".to_string(),
680            ))
681            .await
682            .unwrap();
683        backend
684            .save_snapshot(&WorkflowSnapshot::new(
685                "test-2".to_string(),
686                "hash-2".to_string(),
687            ))
688            .await
689            .unwrap();
690
691        let list = backend.list_snapshots().await.unwrap();
692        assert_eq!(list.len(), 2);
693        assert!(list.contains(&"test-1".to_string()));
694        assert!(list.contains(&"test-2".to_string()));
695    }
696
697    // Task claim tests
698
699    #[tokio::test]
700    async fn test_claim_task_success() {
701        let backend = InMemoryBackend::new();
702
703        let claim = backend
704            .claim_task(
705                "workflow-1",
706                "task-1",
707                "worker-1",
708                Some(Duration::seconds(300)),
709            )
710            .await
711            .unwrap();
712
713        assert!(claim.is_some());
714        let claim = claim.unwrap();
715        assert_eq!(claim.instance_id, "workflow-1");
716        assert_eq!(claim.task_id, "task-1");
717        assert_eq!(claim.worker_id, "worker-1");
718        assert!(claim.expires_at.is_some());
719    }
720
721    #[tokio::test]
722    async fn test_claim_task_already_claimed() {
723        let backend = InMemoryBackend::new();
724
725        // First claim succeeds
726        let claim1 = backend
727            .claim_task(
728                "workflow-1",
729                "task-1",
730                "worker-1",
731                Some(Duration::seconds(300)),
732            )
733            .await
734            .unwrap();
735        assert!(claim1.is_some());
736
737        // Second claim by different worker fails
738        let claim2 = backend
739            .claim_task(
740                "workflow-1",
741                "task-1",
742                "worker-2",
743                Some(Duration::seconds(300)),
744            )
745            .await
746            .unwrap();
747        assert!(claim2.is_none());
748    }
749
750    #[tokio::test]
751    async fn test_claim_task_expired_claim_replaced() {
752        let backend = InMemoryBackend::new();
753
754        // Create a claim with 0 TTL (immediately expired)
755        let claim1 = backend
756            .claim_task(
757                "workflow-1",
758                "task-1",
759                "worker-1",
760                Some(Duration::seconds(0)),
761            )
762            .await
763            .unwrap();
764        assert!(claim1.is_some());
765
766        // Second claim should succeed because first is expired (0-second TTL)
767        let claim2 = backend
768            .claim_task(
769                "workflow-1",
770                "task-1",
771                "worker-2",
772                Some(Duration::seconds(300)),
773            )
774            .await
775            .unwrap();
776        assert!(claim2.is_some());
777        let claim2 = claim2.unwrap();
778        assert_eq!(claim2.worker_id, "worker-2");
779    }
780
781    #[tokio::test]
782    async fn test_claim_task_no_ttl() {
783        let backend = InMemoryBackend::new();
784
785        let claim = backend
786            .claim_task("workflow-1", "task-1", "worker-1", None)
787            .await
788            .unwrap();
789
790        assert!(claim.is_some());
791        let claim = claim.unwrap();
792        assert!(claim.expires_at.is_none());
793        assert!(!claim.is_expired()); // Never expires
794    }
795
796    #[tokio::test]
797    async fn test_release_task_claim_success() {
798        let backend = InMemoryBackend::new();
799
800        // Claim a task
801        backend
802            .claim_task(
803                "workflow-1",
804                "task-1",
805                "worker-1",
806                Some(Duration::seconds(300)),
807            )
808            .await
809            .unwrap();
810
811        // Release it
812        let result = backend
813            .release_task_claim("workflow-1", "task-1", "worker-1")
814            .await;
815        assert!(result.is_ok());
816
817        // Can claim again
818        let claim = backend
819            .claim_task(
820                "workflow-1",
821                "task-1",
822                "worker-2",
823                Some(Duration::seconds(300)),
824            )
825            .await
826            .unwrap();
827        assert!(claim.is_some());
828    }
829
830    #[tokio::test]
831    async fn test_release_task_claim_wrong_worker() {
832        let backend = InMemoryBackend::new();
833
834        // Claim a task as worker-1
835        backend
836            .claim_task(
837                "workflow-1",
838                "task-1",
839                "worker-1",
840                Some(Duration::seconds(300)),
841            )
842            .await
843            .unwrap();
844
845        // Try to release as worker-2
846        let result = backend
847            .release_task_claim("workflow-1", "task-1", "worker-2")
848            .await;
849        assert!(matches!(result, Err(BackendError::Backend(_))));
850    }
851
852    #[tokio::test]
853    async fn test_release_task_claim_not_found() {
854        let backend = InMemoryBackend::new();
855
856        let result = backend
857            .release_task_claim("workflow-1", "task-1", "worker-1")
858            .await;
859        assert!(matches!(result, Err(BackendError::NotFound(_))));
860    }
861
862    #[tokio::test]
863    async fn test_extend_task_claim_success() {
864        let backend = InMemoryBackend::new();
865
866        // Claim a task with short TTL
867        let claim = backend
868            .claim_task(
869                "workflow-1",
870                "task-1",
871                "worker-1",
872                Some(Duration::seconds(10)),
873            )
874            .await
875            .unwrap()
876            .unwrap();
877        let original_expiry = claim.expires_at.unwrap();
878
879        // Extend it
880        backend
881            .extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
882            .await
883            .unwrap();
884
885        // Verify extension by checking internal state
886        let claims = backend.claims.read().unwrap();
887        let key = InMemoryBackend::claim_key("workflow-1", "task-1");
888        let extended_claim = claims.get(&key).unwrap();
889        assert!(extended_claim.expires_at.unwrap() > original_expiry);
890    }
891
892    #[tokio::test]
893    async fn test_extend_task_claim_wrong_worker() {
894        let backend = InMemoryBackend::new();
895
896        // Claim a task as worker-1
897        backend
898            .claim_task(
899                "workflow-1",
900                "task-1",
901                "worker-1",
902                Some(Duration::seconds(300)),
903            )
904            .await
905            .unwrap();
906
907        // Try to extend as worker-2
908        let result = backend
909            .extend_task_claim("workflow-1", "task-1", "worker-2", Duration::seconds(300))
910            .await;
911        assert!(matches!(result, Err(BackendError::Backend(_))));
912    }
913
914    #[tokio::test]
915    async fn test_extend_task_claim_not_found() {
916        let backend = InMemoryBackend::new();
917
918        let result = backend
919            .extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
920            .await;
921        assert!(matches!(result, Err(BackendError::NotFound(_))));
922    }
923
924    #[tokio::test]
925    async fn test_extend_task_claim_no_expiry() {
926        let backend = InMemoryBackend::new();
927
928        // Claim a task with no TTL
929        backend
930            .claim_task("workflow-1", "task-1", "worker-1", None)
931            .await
932            .unwrap();
933
934        // Extending should succeed but not change anything (expires_at stays None)
935        backend
936            .extend_task_claim("workflow-1", "task-1", "worker-1", Duration::seconds(300))
937            .await
938            .unwrap();
939
940        let claims = backend.claims.read().unwrap();
941        let key = InMemoryBackend::claim_key("workflow-1", "task-1");
942        let claim = claims.get(&key).unwrap();
943        assert!(claim.expires_at.is_none());
944    }
945
946    #[tokio::test]
947    async fn test_store_signal_cancel_success() {
948        let backend = InMemoryBackend::new();
949        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
950        backend.save_snapshot(&snapshot).await.unwrap();
951
952        let result = backend
953            .store_signal(
954                "test-123",
955                SignalKind::Cancel,
956                SignalRequest::new(
957                    Some("User requested".to_string()),
958                    Some("admin".to_string()),
959                ),
960            )
961            .await;
962        assert!(result.is_ok(), "store_signal should succeed");
963
964        let stored = backend
965            .get_signal("test-123", SignalKind::Cancel)
966            .await
967            .unwrap();
968        assert!(stored.is_some(), "cancel signal should be stored");
969        let stored = stored.unwrap();
970        assert_eq!(stored.reason, Some("User requested".to_string()));
971        assert_eq!(stored.requested_by, Some("admin".to_string()));
972    }
973
974    #[tokio::test]
975    async fn test_store_signal_cancel_not_found() {
976        let backend = InMemoryBackend::new();
977
978        let result = backend
979            .store_signal(
980                "nonexistent",
981                SignalKind::Cancel,
982                SignalRequest::new(None, None),
983            )
984            .await;
985        assert!(
986            matches!(result, Err(BackendError::NotFound(_))),
987            "should return NotFound for non-existent workflow"
988        );
989    }
990
991    #[tokio::test]
992    async fn test_store_signal_cancel_completed_workflow() {
993        let backend = InMemoryBackend::new();
994        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
995        snapshot.mark_completed(bytes::Bytes::from("result"));
996        backend.save_snapshot(&snapshot).await.unwrap();
997
998        let result = backend
999            .store_signal(
1000                "test-123",
1001                SignalKind::Cancel,
1002                SignalRequest::new(None, None),
1003            )
1004            .await;
1005        assert!(
1006            matches!(result, Err(BackendError::CannotCancel(_))),
1007            "should return CannotCancel for completed workflow"
1008        );
1009    }
1010
1011    #[tokio::test]
1012    async fn test_store_signal_cancel_failed_workflow() {
1013        let backend = InMemoryBackend::new();
1014        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1015        snapshot.mark_failed("Some error".to_string());
1016        backend.save_snapshot(&snapshot).await.unwrap();
1017
1018        let result = backend
1019            .store_signal(
1020                "test-123",
1021                SignalKind::Cancel,
1022                SignalRequest::new(None, None),
1023            )
1024            .await;
1025        assert!(
1026            matches!(result, Err(BackendError::CannotCancel(_))),
1027            "should return CannotCancel for failed workflow"
1028        );
1029    }
1030
1031    #[tokio::test]
1032    async fn test_store_signal_cancel_already_cancelled_idempotent() {
1033        let backend = InMemoryBackend::new();
1034        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1035        snapshot.mark_cancelled(Some("First cancel".to_string()), None, None);
1036        backend.save_snapshot(&snapshot).await.unwrap();
1037
1038        let result = backend
1039            .store_signal(
1040                "test-123",
1041                SignalKind::Cancel,
1042                SignalRequest::new(Some("Second cancel".to_string()), None),
1043            )
1044            .await;
1045        assert!(
1046            result.is_ok(),
1047            "cancelling already-cancelled workflow should be idempotent"
1048        );
1049    }
1050
1051    #[tokio::test]
1052    async fn test_get_signal_cancel_none() {
1053        let backend = InMemoryBackend::new();
1054
1055        let result = backend
1056            .get_signal("test-123", SignalKind::Cancel)
1057            .await
1058            .unwrap();
1059        assert!(
1060            result.is_none(),
1061            "should return None when no cancellation signal exists"
1062        );
1063    }
1064
1065    #[tokio::test]
1066    async fn test_clear_signal_cancel() {
1067        let backend = InMemoryBackend::new();
1068        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1069        backend.save_snapshot(&snapshot).await.unwrap();
1070
1071        backend
1072            .store_signal(
1073                "test-123",
1074                SignalKind::Cancel,
1075                SignalRequest::new(Some("Test".to_string()), None),
1076            )
1077            .await
1078            .unwrap();
1079
1080        assert!(
1081            backend
1082                .get_signal("test-123", SignalKind::Cancel)
1083                .await
1084                .unwrap()
1085                .is_some(),
1086            "cancel signal should exist before clearing"
1087        );
1088
1089        backend
1090            .clear_signal("test-123", SignalKind::Cancel)
1091            .await
1092            .unwrap();
1093
1094        assert!(
1095            backend
1096                .get_signal("test-123", SignalKind::Cancel)
1097                .await
1098                .unwrap()
1099                .is_none(),
1100            "cancel signal should be gone after clearing"
1101        );
1102    }
1103
1104    #[tokio::test]
1105    async fn test_store_signal_pause_completed_workflow() {
1106        let backend = InMemoryBackend::new();
1107        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1108        snapshot.mark_completed(bytes::Bytes::from("result"));
1109        backend.save_snapshot(&snapshot).await.unwrap();
1110
1111        let result = backend
1112            .store_signal(
1113                "test-123",
1114                SignalKind::Pause,
1115                SignalRequest::new(None, None),
1116            )
1117            .await;
1118        assert!(
1119            matches!(result, Err(BackendError::CannotPause(_))),
1120            "should return CannotPause for completed workflow"
1121        );
1122    }
1123
1124    #[tokio::test]
1125    async fn test_store_signal_pause_failed_workflow() {
1126        let backend = InMemoryBackend::new();
1127        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1128        snapshot.mark_failed("Some error".to_string());
1129        backend.save_snapshot(&snapshot).await.unwrap();
1130
1131        let result = backend
1132            .store_signal(
1133                "test-123",
1134                SignalKind::Pause,
1135                SignalRequest::new(None, None),
1136            )
1137            .await;
1138        assert!(
1139            matches!(result, Err(BackendError::CannotPause(_))),
1140            "should return CannotPause for failed workflow"
1141        );
1142    }
1143
1144    #[tokio::test]
1145    async fn test_store_signal_pause_cancelled_workflow() {
1146        let backend = InMemoryBackend::new();
1147        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1148        snapshot.mark_cancelled(Some("done".to_string()), None, None);
1149        backend.save_snapshot(&snapshot).await.unwrap();
1150
1151        let result = backend
1152            .store_signal(
1153                "test-123",
1154                SignalKind::Pause,
1155                SignalRequest::new(None, None),
1156            )
1157            .await;
1158        assert!(
1159            matches!(result, Err(BackendError::CannotPause(_))),
1160            "should return CannotPause for cancelled workflow"
1161        );
1162    }
1163
1164    #[tokio::test]
1165    async fn test_store_signal_pause_already_paused_idempotent() {
1166        let backend = InMemoryBackend::new();
1167        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1168        snapshot.mark_paused(&PauseRequest::new(Some("first".to_string()), None));
1169        backend.save_snapshot(&snapshot).await.unwrap();
1170
1171        let result = backend
1172            .store_signal(
1173                "test-123",
1174                SignalKind::Pause,
1175                SignalRequest::new(Some("second".to_string()), None),
1176            )
1177            .await;
1178        assert!(
1179            result.is_ok(),
1180            "pausing already-paused workflow should be idempotent"
1181        );
1182    }
1183
1184    #[tokio::test]
1185    async fn test_store_signal_pause_not_found() {
1186        let backend = InMemoryBackend::new();
1187        let result = backend
1188            .store_signal(
1189                "nonexistent",
1190                SignalKind::Pause,
1191                SignalRequest::new(None, None),
1192            )
1193            .await;
1194        assert!(
1195            matches!(result, Err(BackendError::NotFound(_))),
1196            "should return NotFound for non-existent workflow"
1197        );
1198    }
1199
1200    #[tokio::test]
1201    async fn test_check_and_cancel_success() {
1202        let backend = InMemoryBackend::new();
1203        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1204        backend.save_snapshot(&snapshot).await.unwrap();
1205
1206        backend
1207            .store_signal(
1208                "test-123",
1209                SignalKind::Cancel,
1210                SignalRequest::new(Some("Timeout".to_string()), Some("system".to_string())),
1211            )
1212            .await
1213            .unwrap();
1214
1215        let result = backend
1216            .check_and_cancel("test-123", Some("task-1"))
1217            .await
1218            .unwrap();
1219        assert!(
1220            result,
1221            "check_and_cancel should return true when cancellation pending"
1222        );
1223
1224        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1225        assert!(
1226            snapshot.state.is_cancelled(),
1227            "workflow should be in cancelled state"
1228        );
1229
1230        let WorkflowSnapshotState::Cancelled {
1231            reason,
1232            cancelled_by,
1233            interrupted_at_task,
1234            ..
1235        } = &snapshot.state
1236        else {
1237            panic!("Expected Cancelled state");
1238        };
1239        assert_eq!(reason, &Some("Timeout".to_string()));
1240        assert_eq!(cancelled_by, &Some("system".to_string()));
1241        assert_eq!(interrupted_at_task, &Some("task-1".to_string()));
1242
1243        assert!(
1244            backend
1245                .get_signal("test-123", SignalKind::Cancel)
1246                .await
1247                .unwrap()
1248                .is_none(),
1249            "cancel signal should be cleared after check_and_cancel"
1250        );
1251    }
1252
1253    #[tokio::test]
1254    async fn test_check_and_cancel_no_request() {
1255        let backend = InMemoryBackend::new();
1256        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1257        backend.save_snapshot(&snapshot).await.unwrap();
1258
1259        let result = backend.check_and_cancel("test-123", None).await.unwrap();
1260        assert!(
1261            !result,
1262            "check_and_cancel should return false when no cancellation pending"
1263        );
1264
1265        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1266        assert!(
1267            snapshot.state.is_in_progress(),
1268            "workflow should still be in progress"
1269        );
1270    }
1271
1272    #[tokio::test]
1273    async fn test_check_and_cancel_not_in_progress() {
1274        let backend = InMemoryBackend::new();
1275        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1276        snapshot.mark_completed(bytes::Bytes::from("done"));
1277        backend.save_snapshot(&snapshot).await.unwrap();
1278
1279        // Add a cancel signal directly (bypassing state check)
1280        {
1281            let mut signals = backend.signals.write().unwrap();
1282            signals
1283                .entry("test-123".to_string())
1284                .or_default()
1285                .insert(SignalKind::Cancel, SignalRequest::new(None, None));
1286        }
1287
1288        let result = backend.check_and_cancel("test-123", None).await.unwrap();
1289        assert!(
1290            !result,
1291            "check_and_cancel should return false for non-in-progress workflow"
1292        );
1293
1294        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1295        assert!(
1296            snapshot.state.is_completed(),
1297            "workflow should still be completed"
1298        );
1299    }
1300
1301    #[tokio::test]
1302    async fn test_find_available_tasks_skips_cancelled_workflows() {
1303        let backend = InMemoryBackend::new();
1304
1305        let mut snapshot1 = WorkflowSnapshot::new("workflow-1".to_string(), "hash-abc".to_string());
1306        snapshot1.update_position(ExecutionPosition::AtTask {
1307            task_id: "task-1".to_string(),
1308        });
1309        backend.save_snapshot(&snapshot1).await.unwrap();
1310
1311        let mut snapshot2 = WorkflowSnapshot::new("workflow-2".to_string(), "hash-abc".to_string());
1312        snapshot2.update_position(ExecutionPosition::AtTask {
1313            task_id: "task-2".to_string(),
1314        });
1315        backend.save_snapshot(&snapshot2).await.unwrap();
1316
1317        backend
1318            .store_signal(
1319                "workflow-1",
1320                SignalKind::Cancel,
1321                SignalRequest::new(None, None),
1322            )
1323            .await
1324            .unwrap();
1325
1326        let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1327
1328        assert!(
1329            !tasks.iter().any(|t| t.instance_id == "workflow-1"),
1330            "workflow with pending cancellation should be skipped"
1331        );
1332    }
1333
1334    // ========================================================================
1335    // check_and_pause tests
1336    // ========================================================================
1337
1338    #[tokio::test]
1339    async fn test_check_and_pause_success() {
1340        let backend = InMemoryBackend::new();
1341        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1342        backend.save_snapshot(&snapshot).await.unwrap();
1343
1344        backend
1345            .store_signal(
1346                "test-123",
1347                SignalKind::Pause,
1348                SignalRequest::new(Some("maintenance".to_string()), Some("ops".to_string())),
1349            )
1350            .await
1351            .unwrap();
1352
1353        let result = backend.check_and_pause("test-123").await.unwrap();
1354        assert!(
1355            result,
1356            "check_and_pause should return true when pause pending"
1357        );
1358
1359        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1360        assert!(snapshot.state.is_paused(), "workflow should be paused");
1361
1362        let WorkflowSnapshotState::Paused {
1363            reason, paused_by, ..
1364        } = &snapshot.state
1365        else {
1366            panic!("Expected Paused state");
1367        };
1368        assert_eq!(reason, &Some("maintenance".to_string()));
1369        assert_eq!(paused_by, &Some("ops".to_string()));
1370
1371        assert!(
1372            backend
1373                .get_signal("test-123", SignalKind::Pause)
1374                .await
1375                .unwrap()
1376                .is_none(),
1377            "pause signal should be cleared after check_and_pause"
1378        );
1379    }
1380
1381    #[tokio::test]
1382    async fn test_check_and_pause_no_request() {
1383        let backend = InMemoryBackend::new();
1384        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1385        backend.save_snapshot(&snapshot).await.unwrap();
1386
1387        let result = backend.check_and_pause("test-123").await.unwrap();
1388        assert!(
1389            !result,
1390            "check_and_pause should return false when no pause pending"
1391        );
1392
1393        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1394        assert!(
1395            snapshot.state.is_in_progress(),
1396            "workflow should still be in progress"
1397        );
1398    }
1399
1400    #[tokio::test]
1401    async fn test_check_and_pause_not_in_progress() {
1402        let backend = InMemoryBackend::new();
1403        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1404        snapshot.mark_completed(bytes::Bytes::from("done"));
1405        backend.save_snapshot(&snapshot).await.unwrap();
1406
1407        // Add a pause signal directly (bypassing state check)
1408        {
1409            let mut signals = backend.signals.write().unwrap();
1410            signals
1411                .entry("test-123".to_string())
1412                .or_default()
1413                .insert(SignalKind::Pause, SignalRequest::new(None, None));
1414        }
1415
1416        let result = backend.check_and_pause("test-123").await.unwrap();
1417        assert!(
1418            !result,
1419            "check_and_pause should return false for non-in-progress workflow"
1420        );
1421
1422        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1423        assert!(
1424            snapshot.state.is_completed(),
1425            "workflow should still be completed"
1426        );
1427    }
1428
1429    #[tokio::test]
1430    async fn test_check_and_pause_preserves_position() {
1431        let backend = InMemoryBackend::new();
1432        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1433        snapshot.update_position(ExecutionPosition::AtTask {
1434            task_id: "task-3".to_string(),
1435        });
1436        snapshot.mark_task_completed("task-1".to_string(), bytes::Bytes::from("out1"));
1437        snapshot.mark_task_completed("task-2".to_string(), bytes::Bytes::from("out2"));
1438        backend.save_snapshot(&snapshot).await.unwrap();
1439
1440        backend
1441            .store_signal(
1442                "test-123",
1443                SignalKind::Pause,
1444                SignalRequest::new(None, None),
1445            )
1446            .await
1447            .unwrap();
1448
1449        backend.check_and_pause("test-123").await.unwrap();
1450
1451        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1452        let WorkflowSnapshotState::Paused {
1453            completed_tasks,
1454            position,
1455            last_completed_task_id,
1456            ..
1457        } = &snapshot.state
1458        else {
1459            panic!("Expected Paused state");
1460        };
1461
1462        assert_eq!(completed_tasks.len(), 2);
1463        assert!(completed_tasks.contains_key("task-1"));
1464        assert!(completed_tasks.contains_key("task-2"));
1465        assert!(matches!(
1466            position,
1467            ExecutionPosition::AtTask { task_id } if task_id == "task-3"
1468        ));
1469        assert_eq!(last_completed_task_id, &Some("task-2".to_string()));
1470    }
1471
1472    // ========================================================================
1473    // unpause tests
1474    // ========================================================================
1475
1476    #[tokio::test]
1477    async fn test_unpause_success() {
1478        let backend = InMemoryBackend::new();
1479        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1480        snapshot.update_position(ExecutionPosition::AtTask {
1481            task_id: "task-2".to_string(),
1482        });
1483        snapshot.mark_task_completed("task-1".to_string(), bytes::Bytes::from("out1"));
1484        snapshot.mark_paused(&PauseRequest::new(
1485            Some("maintenance".to_string()),
1486            Some("ops".to_string()),
1487        ));
1488        backend.save_snapshot(&snapshot).await.unwrap();
1489
1490        let result = backend.unpause("test-123").await.unwrap();
1491
1492        assert!(
1493            result.state.is_in_progress(),
1494            "unpaused workflow should be in progress"
1495        );
1496
1497        // Verify position and tasks were restored
1498        let WorkflowSnapshotState::InProgress {
1499            position,
1500            completed_tasks,
1501            last_completed_task_id,
1502        } = &result.state
1503        else {
1504            panic!("Expected InProgress state");
1505        };
1506        assert!(matches!(
1507            position,
1508            ExecutionPosition::AtTask { task_id } if task_id == "task-2"
1509        ));
1510        assert!(completed_tasks.contains_key("task-1"));
1511        assert_eq!(last_completed_task_id, &Some("task-1".to_string()));
1512
1513        // Verify persisted state matches
1514        let loaded = backend.load_snapshot("test-123").await.unwrap();
1515        assert!(loaded.state.is_in_progress());
1516    }
1517
1518    #[tokio::test]
1519    async fn test_unpause_not_paused_errors() {
1520        let backend = InMemoryBackend::new();
1521        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1522        backend.save_snapshot(&snapshot).await.unwrap();
1523
1524        let result = backend.unpause("test-123").await;
1525        assert!(
1526            matches!(result, Err(BackendError::CannotPause(_))),
1527            "unpause on in-progress workflow should error"
1528        );
1529    }
1530
1531    #[tokio::test]
1532    async fn test_unpause_completed_errors() {
1533        let backend = InMemoryBackend::new();
1534        let mut snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1535        snapshot.mark_completed(bytes::Bytes::from("done"));
1536        backend.save_snapshot(&snapshot).await.unwrap();
1537
1538        let result = backend.unpause("test-123").await;
1539        assert!(
1540            matches!(result, Err(BackendError::CannotPause(_))),
1541            "unpause on completed workflow should error"
1542        );
1543    }
1544
1545    #[tokio::test]
1546    async fn test_unpause_not_found() {
1547        let backend = InMemoryBackend::new();
1548        let result = backend.unpause("nonexistent").await;
1549        assert!(matches!(result, Err(BackendError::NotFound(_))));
1550    }
1551
1552    // ========================================================================
1553    // Concurrent signals tests
1554    // ========================================================================
1555
1556    #[tokio::test]
1557    async fn test_cancel_and_pause_simultaneously_cancel_wins() {
1558        let backend = InMemoryBackend::new();
1559        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1560        backend.save_snapshot(&snapshot).await.unwrap();
1561
1562        // Store both signals
1563        backend
1564            .store_signal(
1565                "test-123",
1566                SignalKind::Cancel,
1567                SignalRequest::new(Some("cancel reason".to_string()), None),
1568            )
1569            .await
1570            .unwrap();
1571        backend
1572            .store_signal(
1573                "test-123",
1574                SignalKind::Pause,
1575                SignalRequest::new(Some("pause reason".to_string()), None),
1576            )
1577            .await
1578            .unwrap();
1579
1580        // check_and_cancel should process the cancel signal
1581        let cancelled = backend
1582            .check_and_cancel("test-123", Some("task-1"))
1583            .await
1584            .unwrap();
1585        assert!(cancelled, "cancel should succeed");
1586
1587        // Now check_and_pause — workflow is already cancelled (not in progress)
1588        let paused = backend.check_and_pause("test-123").await.unwrap();
1589        assert!(
1590            !paused,
1591            "pause should return false since workflow is already cancelled"
1592        );
1593
1594        let snapshot = backend.load_snapshot("test-123").await.unwrap();
1595        assert!(snapshot.state.is_cancelled());
1596    }
1597
1598    #[tokio::test]
1599    async fn test_cancel_signal_independent_of_pause_signal() {
1600        let backend = InMemoryBackend::new();
1601        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1602        backend.save_snapshot(&snapshot).await.unwrap();
1603
1604        // Store both signals
1605        backend
1606            .store_signal(
1607                "test-123",
1608                SignalKind::Cancel,
1609                SignalRequest::new(Some("cancel".to_string()), None),
1610            )
1611            .await
1612            .unwrap();
1613        backend
1614            .store_signal(
1615                "test-123",
1616                SignalKind::Pause,
1617                SignalRequest::new(Some("pause".to_string()), None),
1618            )
1619            .await
1620            .unwrap();
1621
1622        // Clear only cancel
1623        backend
1624            .clear_signal("test-123", SignalKind::Cancel)
1625            .await
1626            .unwrap();
1627
1628        // Cancel should be gone, pause should remain
1629        assert!(
1630            backend
1631                .get_signal("test-123", SignalKind::Cancel)
1632                .await
1633                .unwrap()
1634                .is_none()
1635        );
1636        assert!(
1637            backend
1638                .get_signal("test-123", SignalKind::Pause)
1639                .await
1640                .unwrap()
1641                .is_some()
1642        );
1643    }
1644
1645    // ========================================================================
1646    // find_available_tasks + pause signal
1647    // ========================================================================
1648
1649    #[tokio::test]
1650    async fn test_find_available_tasks_skips_paused_workflows() {
1651        let backend = InMemoryBackend::new();
1652
1653        let mut snapshot1 = WorkflowSnapshot::with_initial_input(
1654            "workflow-1".to_string(),
1655            "hash-abc".to_string(),
1656            bytes::Bytes::from(vec![1]),
1657        );
1658        snapshot1.update_position(ExecutionPosition::AtTask {
1659            task_id: "task-1".to_string(),
1660        });
1661        backend.save_snapshot(&snapshot1).await.unwrap();
1662
1663        let mut snapshot2 = WorkflowSnapshot::with_initial_input(
1664            "workflow-2".to_string(),
1665            "hash-abc".to_string(),
1666            bytes::Bytes::from(vec![2]),
1667        );
1668        snapshot2.update_position(ExecutionPosition::AtTask {
1669            task_id: "task-2".to_string(),
1670        });
1671        backend.save_snapshot(&snapshot2).await.unwrap();
1672
1673        // Pause workflow-1
1674        backend
1675            .store_signal(
1676                "workflow-1",
1677                SignalKind::Pause,
1678                SignalRequest::new(None, None),
1679            )
1680            .await
1681            .unwrap();
1682
1683        let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1684
1685        assert!(
1686            !tasks.iter().any(|t| t.instance_id == "workflow-1"),
1687            "workflow with pending pause should be skipped"
1688        );
1689        assert!(
1690            tasks.iter().any(|t| t.instance_id == "workflow-2"),
1691            "workflow without signals should be available"
1692        );
1693    }
1694
1695    // ========================================================================
1696    // Orphaned signals
1697    // ========================================================================
1698
1699    #[tokio::test]
1700    async fn test_delete_snapshot_leaves_orphaned_signals() {
1701        let backend = InMemoryBackend::new();
1702        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1703        backend.save_snapshot(&snapshot).await.unwrap();
1704
1705        backend
1706            .store_signal(
1707                "test-123",
1708                SignalKind::Cancel,
1709                SignalRequest::new(Some("reason".to_string()), None),
1710            )
1711            .await
1712            .unwrap();
1713
1714        // Delete the snapshot
1715        backend.delete_snapshot("test-123").await.unwrap();
1716
1717        // Signal is still there (orphaned) — this documents current behavior
1718        let signal = backend
1719            .get_signal("test-123", SignalKind::Cancel)
1720            .await
1721            .unwrap();
1722        assert!(
1723            signal.is_some(),
1724            "signal persists after snapshot deletion (orphaned)"
1725        );
1726    }
1727
1728    #[tokio::test]
1729    async fn test_store_signal_overwrites_previous() {
1730        let backend = InMemoryBackend::new();
1731        let snapshot = WorkflowSnapshot::new("test-123".to_string(), "hash-abc".to_string());
1732        backend.save_snapshot(&snapshot).await.unwrap();
1733
1734        backend
1735            .store_signal(
1736                "test-123",
1737                SignalKind::Cancel,
1738                SignalRequest::new(Some("first".to_string()), None),
1739            )
1740            .await
1741            .unwrap();
1742        backend
1743            .store_signal(
1744                "test-123",
1745                SignalKind::Cancel,
1746                SignalRequest::new(Some("second".to_string()), None),
1747            )
1748            .await
1749            .unwrap();
1750
1751        let signal = backend
1752            .get_signal("test-123", SignalKind::Cancel)
1753            .await
1754            .unwrap()
1755            .unwrap();
1756        assert_eq!(
1757            signal.reason,
1758            Some("second".to_string()),
1759            "latest signal should overwrite previous"
1760        );
1761    }
1762
1763    // ========================================================================
1764    // Delay tests
1765    // ========================================================================
1766
1767    #[tokio::test]
1768    async fn test_find_available_tasks_skips_unexpired_delay() {
1769        let backend = InMemoryBackend::new();
1770
1771        let mut snapshot = WorkflowSnapshot::with_initial_input(
1772            "workflow-1".to_string(),
1773            "hash-abc".to_string(),
1774            bytes::Bytes::from(vec![42]),
1775        );
1776        // Park at a delay that expires in the future
1777        let wake_at = Utc::now() + chrono::Duration::hours(1);
1778        snapshot.update_position(ExecutionPosition::AtDelay {
1779            delay_id: "wait_1h".to_string(),
1780            entered_at: Utc::now(),
1781            wake_at,
1782            next_task_id: Some("next_step".to_string()),
1783        });
1784        snapshot.mark_task_completed("wait_1h".to_string(), bytes::Bytes::from(vec![42]));
1785        backend.save_snapshot(&snapshot).await.unwrap();
1786
1787        let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1788        assert!(
1789            tasks.is_empty(),
1790            "workflow at unexpired delay should not appear in available tasks"
1791        );
1792    }
1793
1794    #[tokio::test]
1795    async fn test_find_available_tasks_advances_expired_delay() {
1796        let backend = InMemoryBackend::new();
1797
1798        let mut snapshot = WorkflowSnapshot::with_initial_input(
1799            "workflow-1".to_string(),
1800            "hash-abc".to_string(),
1801            bytes::Bytes::from(vec![42]),
1802        );
1803        // Park at a delay that has already expired
1804        let wake_at = Utc::now() - chrono::Duration::seconds(1);
1805        snapshot.update_position(ExecutionPosition::AtDelay {
1806            delay_id: "wait_done".to_string(),
1807            entered_at: Utc::now() - chrono::Duration::seconds(2),
1808            wake_at,
1809            next_task_id: Some("process".to_string()),
1810        });
1811        snapshot.mark_task_completed("wait_done".to_string(), bytes::Bytes::from(vec![42]));
1812        backend.save_snapshot(&snapshot).await.unwrap();
1813
1814        let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1815
1816        // The delay has expired, so the position should have been advanced to "process"
1817        assert_eq!(tasks.len(), 1);
1818        assert_eq!(tasks[0].instance_id, "workflow-1");
1819        assert_eq!(tasks[0].task_id, "process");
1820
1821        // Verify position was advanced in the snapshot
1822        let loaded = backend.load_snapshot("workflow-1").await.unwrap();
1823        match &loaded.state {
1824            WorkflowSnapshotState::InProgress {
1825                position: ExecutionPosition::AtTask { task_id },
1826                ..
1827            } => {
1828                assert_eq!(task_id, "process");
1829            }
1830            other => panic!("Expected AtTask position, got {other:?}"),
1831        }
1832    }
1833
1834    #[tokio::test]
1835    async fn test_find_available_tasks_completes_expired_delay_last_node() {
1836        let backend = InMemoryBackend::new();
1837
1838        let mut snapshot = WorkflowSnapshot::with_initial_input(
1839            "workflow-1".to_string(),
1840            "hash-abc".to_string(),
1841            bytes::Bytes::from(vec![42]),
1842        );
1843        // Park at a delay that has expired AND has no next task (delay is last node)
1844        let wake_at = Utc::now() - chrono::Duration::seconds(1);
1845        snapshot.update_position(ExecutionPosition::AtDelay {
1846            delay_id: "final_wait".to_string(),
1847            entered_at: Utc::now() - chrono::Duration::seconds(2),
1848            wake_at,
1849            next_task_id: None,
1850        });
1851        snapshot.mark_task_completed("final_wait".to_string(), bytes::Bytes::from(vec![42]));
1852        backend.save_snapshot(&snapshot).await.unwrap();
1853
1854        let tasks = backend.find_available_tasks("worker-1", 10).await.unwrap();
1855
1856        // No available tasks — the workflow should have been marked completed
1857        assert!(
1858            tasks.is_empty(),
1859            "completed workflow should not appear in available tasks"
1860        );
1861
1862        // Verify workflow was marked completed
1863        let loaded = backend.load_snapshot("workflow-1").await.unwrap();
1864        assert!(
1865            loaded.state.is_completed(),
1866            "workflow should be completed when delay is last node and expired"
1867        );
1868    }
1869
1870    // ── Event queue (send_event / consume_event) ────────────────────────
1871
1872    #[tokio::test]
1873    async fn test_send_and_consume_event_fifo() {
1874        let backend = InMemoryBackend::new();
1875
1876        // Send two events for the same signal
1877        backend
1878            .send_event("wf-1", "approval", bytes::Bytes::from("first"))
1879            .await
1880            .unwrap();
1881        backend
1882            .send_event("wf-1", "approval", bytes::Bytes::from("second"))
1883            .await
1884            .unwrap();
1885
1886        // Consume should return FIFO order
1887        let first = backend.consume_event("wf-1", "approval").await.unwrap();
1888        assert_eq!(first.as_deref(), Some(b"first".as_slice()));
1889
1890        let second = backend.consume_event("wf-1", "approval").await.unwrap();
1891        assert_eq!(second.as_deref(), Some(b"second".as_slice()));
1892
1893        // Queue is now empty
1894        let none = backend.consume_event("wf-1", "approval").await.unwrap();
1895        assert!(none.is_none());
1896    }
1897
1898    #[tokio::test]
1899    async fn test_consume_event_empty_returns_none() {
1900        let backend = InMemoryBackend::new();
1901        let result = backend.consume_event("wf-1", "nonexistent").await.unwrap();
1902        assert!(result.is_none());
1903    }
1904
1905    #[tokio::test]
1906    async fn test_events_are_isolated_by_signal_name() {
1907        let backend = InMemoryBackend::new();
1908
1909        backend
1910            .send_event("wf-1", "sig_a", bytes::Bytes::from("a_payload"))
1911            .await
1912            .unwrap();
1913        backend
1914            .send_event("wf-1", "sig_b", bytes::Bytes::from("b_payload"))
1915            .await
1916            .unwrap();
1917
1918        // Consuming sig_a should not affect sig_b
1919        let a = backend.consume_event("wf-1", "sig_a").await.unwrap();
1920        assert_eq!(a.as_deref(), Some(b"a_payload".as_slice()));
1921
1922        let b = backend.consume_event("wf-1", "sig_b").await.unwrap();
1923        assert_eq!(b.as_deref(), Some(b"b_payload".as_slice()));
1924    }
1925}