Skip to main content

symbi_runtime/reasoning/
saga.rs

1//! Saga orchestrator for multi-step tool sequences
2//!
3//! Provides forward execution with backward compensation on failure.
4//! Tool actions are classified as ReadOnly, Compensatable, or Final,
5//! enabling automatic rollback when a sequence fails partway through.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12/// Classification of a saga step's side effects.
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub enum StepClassification {
15    /// No side effects. Safe to skip during compensation.
16    ReadOnly,
17    /// Has side effects but can be reversed via a compensation action.
18    Compensatable,
19    /// Irreversible. Only permitted after all preceding steps succeed.
20    Final,
21}
22
23/// A single step in a saga.
24#[derive(Debug, Clone)]
25pub struct SagaStep {
26    /// Step name for logging and audit.
27    pub name: String,
28    /// Side-effect classification.
29    pub classification: StepClassification,
30    /// The forward action to execute.
31    pub action: SagaAction,
32    /// The compensation action (only for Compensatable steps).
33    pub compensation: Option<SagaAction>,
34}
35
36/// An action in a saga (either forward or compensation).
37#[derive(Debug, Clone)]
38pub struct SagaAction {
39    /// Tool name to invoke.
40    pub tool_name: String,
41    /// Arguments for the tool.
42    pub arguments: serde_json::Value,
43}
44
/// Outcome of running one action (forward, compensation, or recovery).
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Label of the action that produced this result.
    pub name: String,
    /// `true` when the executor reported success.
    pub success: bool,
    /// Executor output on success; empty otherwise.
    pub output: String,
    /// Failure description, if any.
    pub error: Option<String>,
}
57
58/// Overall saga execution result.
59#[derive(Debug, Clone)]
60pub struct SagaResult {
61    /// Whether the entire saga succeeded.
62    pub success: bool,
63    /// Results from each step (forward execution).
64    pub step_results: Vec<StepResult>,
65    /// Results from compensation steps (if saga failed).
66    pub compensation_results: Vec<StepResult>,
67    /// Summary of what happened.
68    pub summary: String,
69}
70
71/// Unique key for idempotent execution of saga steps.
72///
73/// Enables safe retries of Final steps: if a step has already been executed
74/// with a given key, re-execution is skipped.
75#[derive(Debug, Clone, PartialEq, Eq, Hash)]
76pub struct IdempotencyKey(pub String);
77
78impl IdempotencyKey {
79    /// Generate a new unique idempotency key.
80    pub fn generate() -> Self {
81        Self(uuid::Uuid::new_v4().to_string())
82    }
83}
84
85/// Persistence hook called before Final steps execute.
86///
87/// Records intent before execution and outcome after, enabling recovery
88/// of Final steps that were started but not confirmed (e.g., process crash).
89#[async_trait::async_trait]
90pub trait SagaCheckpoint: Send + Sync {
91    /// Record that we intend to execute this Final step.
92    async fn record_intent(&self, step: &SagaStep, key: &IdempotencyKey) -> Result<(), SagaError>;
93
94    /// Record the outcome of a Final step execution.
95    async fn record_outcome(&self, key: &IdempotencyKey, success: bool) -> Result<(), SagaError>;
96
97    /// Return all intents that were recorded but have no corresponding outcome.
98    async fn pending_intents(&self) -> Result<Vec<(SagaStep, IdempotencyKey)>, SagaError>;
99}
100
101/// In-memory checkpoint implementation.
102///
103/// Suitable for testing and single-process use. Pairs with the
104/// `MemoryJournalStorage` philosophy — no external dependencies required.
105pub struct InMemoryCheckpoint {
106    intents: Mutex<Vec<(SagaStep, IdempotencyKey)>>,
107    outcomes: Mutex<HashMap<String, bool>>,
108}
109
110impl Default for InMemoryCheckpoint {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl InMemoryCheckpoint {
117    pub fn new() -> Self {
118        Self {
119            intents: Mutex::new(Vec::new()),
120            outcomes: Mutex::new(HashMap::new()),
121        }
122    }
123}
124
125#[async_trait::async_trait]
126impl SagaCheckpoint for InMemoryCheckpoint {
127    async fn record_intent(&self, step: &SagaStep, key: &IdempotencyKey) -> Result<(), SagaError> {
128        self.intents.lock().await.push((step.clone(), key.clone()));
129        Ok(())
130    }
131
132    async fn record_outcome(&self, key: &IdempotencyKey, success: bool) -> Result<(), SagaError> {
133        self.outcomes.lock().await.insert(key.0.clone(), success);
134        Ok(())
135    }
136
137    async fn pending_intents(&self) -> Result<Vec<(SagaStep, IdempotencyKey)>, SagaError> {
138        let intents = self.intents.lock().await;
139        let outcomes = self.outcomes.lock().await;
140        Ok(intents
141            .iter()
142            .filter(|(_, key)| !outcomes.contains_key(&key.0))
143            .cloned()
144            .collect())
145    }
146}
147
148/// Orchestrates saga execution with compensation.
149pub struct SagaOrchestrator {
150    steps: Vec<SagaStep>,
151}
152
153impl SagaOrchestrator {
154    /// Create a new saga with the given steps.
155    ///
156    /// Validates that Final steps only appear after all Compensatable steps.
157    pub fn new(steps: Vec<SagaStep>) -> Result<Self, SagaError> {
158        // Validate: no Compensatable or ReadOnly steps after a Final step
159        let mut seen_final = false;
160        for step in &steps {
161            if seen_final && step.classification != StepClassification::Final {
162                return Err(SagaError::InvalidStepOrder {
163                    step: step.name.clone(),
164                    reason: "Non-final steps cannot appear after a Final step".into(),
165                });
166            }
167            if step.classification == StepClassification::Final {
168                seen_final = true;
169            }
170        }
171
172        // Validate: Compensatable steps must have compensation actions
173        for step in &steps {
174            if step.classification == StepClassification::Compensatable
175                && step.compensation.is_none()
176            {
177                return Err(SagaError::MissingCompensation {
178                    step: step.name.clone(),
179                });
180            }
181        }
182
183        Ok(Self { steps })
184    }
185
186    /// Execute the saga using the provided executor function.
187    ///
188    /// The executor takes (tool_name, arguments) and returns (success, output).
189    pub async fn execute<F, Fut>(&self, executor: F) -> SagaResult
190    where
191        F: Fn(String, serde_json::Value) -> Fut,
192        Fut: std::future::Future<Output = Result<String, String>>,
193    {
194        self.execute_inner(&executor, None).await
195    }
196
197    /// Execute the saga with a checkpoint for Final step durability.
198    ///
199    /// Before each Final step, records intent via the checkpoint. After
200    /// execution, records the outcome. On crash, `recover()` can identify
201    /// Final steps that were intended but not confirmed.
202    pub async fn execute_with_checkpoint<F, Fut>(
203        &self,
204        executor: F,
205        checkpoint: Arc<dyn SagaCheckpoint>,
206    ) -> SagaResult
207    where
208        F: Fn(String, serde_json::Value) -> Fut,
209        Fut: std::future::Future<Output = Result<String, String>>,
210    {
211        self.execute_inner(&executor, Some(checkpoint)).await
212    }
213
214    async fn execute_inner<F, Fut>(
215        &self,
216        executor: &F,
217        checkpoint: Option<Arc<dyn SagaCheckpoint>>,
218    ) -> SagaResult
219    where
220        F: Fn(String, serde_json::Value) -> Fut,
221        Fut: std::future::Future<Output = Result<String, String>>,
222    {
223        let mut step_results = Vec::new();
224        let mut completed_compensatable: Vec<&SagaStep> = Vec::new();
225
226        for step in &self.steps {
227            // For Final steps with a checkpoint, record intent before execution
228            let idempotency_key = if step.classification == StepClassification::Final {
229                if let Some(cp) = &checkpoint {
230                    let key = IdempotencyKey::generate();
231                    if let Err(e) = cp.record_intent(step, &key).await {
232                        return SagaResult {
233                            success: false,
234                            step_results,
235                            compensation_results: self
236                                .compensate(&completed_compensatable, executor)
237                                .await,
238                            summary: format!(
239                                "Failed to record intent for Final step '{}': {}",
240                                step.name, e
241                            ),
242                        };
243                    }
244                    Some(key)
245                } else {
246                    None
247                }
248            } else {
249                None
250            };
251
252            let result =
253                executor(step.action.tool_name.clone(), step.action.arguments.clone()).await;
254
255            match result {
256                Ok(output) => {
257                    // Record successful outcome for Final steps
258                    if let Some(key) = &idempotency_key {
259                        if let Some(cp) = &checkpoint {
260                            let _ = cp.record_outcome(key, true).await;
261                        }
262                    }
263
264                    step_results.push(StepResult {
265                        name: step.name.clone(),
266                        success: true,
267                        output,
268                        error: None,
269                    });
270
271                    if step.classification == StepClassification::Compensatable {
272                        completed_compensatable.push(step);
273                    }
274                }
275                Err(error) => {
276                    // Record failed outcome for Final steps
277                    if let Some(key) = &idempotency_key {
278                        if let Some(cp) = &checkpoint {
279                            let _ = cp.record_outcome(key, false).await;
280                        }
281                    }
282
283                    step_results.push(StepResult {
284                        name: step.name.clone(),
285                        success: false,
286                        output: String::new(),
287                        error: Some(error.clone()),
288                    });
289
290                    // Compensate in reverse order
291                    let compensation_results =
292                        self.compensate(&completed_compensatable, executor).await;
293
294                    return SagaResult {
295                        success: false,
296                        step_results,
297                        compensation_results,
298                        summary: format!("Saga failed at step '{}': {}", step.name, error),
299                    };
300                }
301            }
302        }
303
304        SagaResult {
305            success: true,
306            step_results,
307            compensation_results: Vec::new(),
308            summary: "Saga completed successfully".into(),
309        }
310    }
311
312    /// Recover pending Final steps from a checkpoint.
313    ///
314    /// Returns the list of Final steps that were intended but never got an
315    /// outcome recorded (e.g., due to process crash). The caller can then
316    /// decide to re-execute or report them.
317    pub async fn recover<F, Fut>(
318        checkpoint: &dyn SagaCheckpoint,
319        executor: F,
320    ) -> Result<Vec<StepResult>, SagaError>
321    where
322        F: Fn(String, serde_json::Value) -> Fut,
323        Fut: std::future::Future<Output = Result<String, String>>,
324    {
325        let pending = checkpoint.pending_intents().await?;
326        let mut results = Vec::new();
327
328        for (step, key) in &pending {
329            let result =
330                executor(step.action.tool_name.clone(), step.action.arguments.clone()).await;
331
332            let success = result.is_ok();
333            let _ = checkpoint.record_outcome(key, success).await;
334
335            results.push(StepResult {
336                name: format!("recover:{}", step.name),
337                success,
338                output: result.as_ref().cloned().unwrap_or_default(),
339                error: result.err(),
340            });
341        }
342
343        Ok(results)
344    }
345
346    async fn compensate<F, Fut>(&self, completed: &[&SagaStep], executor: &F) -> Vec<StepResult>
347    where
348        F: Fn(String, serde_json::Value) -> Fut,
349        Fut: std::future::Future<Output = Result<String, String>>,
350    {
351        let mut results = Vec::new();
352
353        // Compensate in reverse order
354        for step in completed.iter().rev() {
355            if let Some(compensation) = &step.compensation {
356                let result = executor(
357                    compensation.tool_name.clone(),
358                    compensation.arguments.clone(),
359                )
360                .await;
361
362                results.push(StepResult {
363                    name: format!("compensate:{}", step.name),
364                    success: result.is_ok(),
365                    output: result.as_ref().cloned().unwrap_or_default(),
366                    error: result.err(),
367                });
368            }
369        }
370
371        results
372    }
373
374    /// Get the step count.
375    pub fn step_count(&self) -> usize {
376        self.steps.len()
377    }
378
379    /// Get steps by classification.
380    pub fn steps_by_classification(&self) -> HashMap<String, usize> {
381        let mut counts = HashMap::new();
382        for step in &self.steps {
383            let key = format!("{:?}", step.classification);
384            *counts.entry(key).or_insert(0) += 1;
385        }
386        counts
387    }
388}
389
390/// Errors from the saga orchestrator.
391#[derive(Debug, thiserror::Error)]
392pub enum SagaError {
393    #[error("Invalid step order for '{step}': {reason}")]
394    InvalidStepOrder { step: String, reason: String },
395
396    #[error("Compensatable step '{step}' is missing a compensation action")]
397    MissingCompensation { step: String },
398
399    #[error("Checkpoint operation failed: {0}")]
400    CheckpointFailed(String),
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a `SagaAction` for `tool` with the given JSON arguments.
    fn action(tool: &str, arguments: serde_json::Value) -> SagaAction {
        SagaAction {
            tool_name: tool.to_string(),
            arguments,
        }
    }

    /// A side-effect-free step that "reads" `name`.
    fn read_step(name: &str) -> SagaStep {
        SagaStep {
            name: name.to_string(),
            classification: StepClassification::ReadOnly,
            action: action("read", serde_json::json!({ "path": name })),
            compensation: None,
        }
    }

    /// A reversible step: writes `name` and is compensated by deleting it.
    fn write_step(name: &str) -> SagaStep {
        SagaStep {
            name: name.to_string(),
            classification: StepClassification::Compensatable,
            action: action(
                "write",
                serde_json::json!({ "path": name, "data": "content" }),
            ),
            compensation: Some(action("delete", serde_json::json!({ "path": name }))),
        }
    }

    /// An irreversible step that publishes to `name`.
    fn final_step(name: &str) -> SagaStep {
        SagaStep {
            name: name.to_string(),
            classification: StepClassification::Final,
            action: action("publish", serde_json::json!({ "target": name })),
            compensation: None,
        }
    }

    #[test]
    fn valid_ordering_is_accepted() {
        let orchestrator = SagaOrchestrator::new(vec![
            read_step("check"),
            write_step("create"),
            write_step("update"),
            final_step("publish"),
        ])
        .expect("reads, then writes, then a final step is a valid order");
        assert_eq!(orchestrator.step_count(), 4);
    }

    #[test]
    fn read_after_final_is_rejected() {
        let outcome = SagaOrchestrator::new(vec![
            write_step("create"),
            final_step("publish"),
            read_step("check"),
        ]);
        assert!(outcome.is_err(), "a ReadOnly step may not follow a Final step");
    }

    #[test]
    fn compensatable_without_compensation_is_rejected() {
        let mut step = write_step("bad");
        step.action.arguments = serde_json::json!({});
        step.compensation = None;
        assert!(SagaOrchestrator::new(vec![step]).is_err());
    }

    #[tokio::test]
    async fn all_steps_succeed_without_compensation() {
        let orchestrator = SagaOrchestrator::new(vec![
            read_step("check"),
            write_step("create"),
            write_step("update"),
        ])
        .unwrap();

        let outcome = orchestrator
            .execute(|_tool, _args| async { Ok("success".to_string()) })
            .await;

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 3);
        assert!(outcome.compensation_results.is_empty());
    }

    #[tokio::test]
    async fn failure_at_fourth_step_compensates_earlier_writes_in_reverse() {
        let orchestrator = SagaOrchestrator::new(vec![
            read_step("check"),
            write_step("create"),
            write_step("update"),
            write_step("finalize"),
            final_step("publish"),
        ])
        .unwrap();

        // The fourth executor call (index 3, "finalize") fails; every other call succeeds.
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let outcome = orchestrator
            .execute(move |_tool, _args| {
                let index = counter.fetch_add(1, Ordering::Relaxed);
                async move {
                    match index {
                        3 => Err("Connection refused".to_string()),
                        _ => Ok("ok".to_string()),
                    }
                }
            })
            .await;

        assert!(!outcome.success);
        // check, create and update succeeded; finalize failed; publish never ran.
        assert_eq!(outcome.step_results.len(), 4);
        assert!(!outcome.step_results[3].success);

        let compensated: Vec<&str> = outcome
            .compensation_results
            .iter()
            .map(|r| r.name.as_str())
            .collect();
        assert_eq!(compensated, ["compensate:update", "compensate:create"]);
    }

    #[tokio::test]
    async fn failing_read_only_step_has_nothing_to_compensate() {
        let orchestrator = SagaOrchestrator::new(vec![read_step("check")]).unwrap();

        let outcome = orchestrator
            .execute(|_tool, _args| async { Err("fail".to_string()) })
            .await;

        assert!(!outcome.success);
        assert!(outcome.compensation_results.is_empty());
    }

    #[test]
    fn classification_counts() {
        let orchestrator = SagaOrchestrator::new(vec![
            read_step("r1"),
            write_step("w1"),
            write_step("w2"),
            final_step("f1"),
        ])
        .unwrap();

        let counts = orchestrator.steps_by_classification();
        for (label, expected) in [("ReadOnly", 1), ("Compensatable", 2), ("Final", 1)] {
            assert_eq!(counts.get(label), Some(&expected), "count for {}", label);
        }
    }

    #[tokio::test]
    async fn checkpointed_final_step_records_intent_and_outcome() {
        let checkpoint = Arc::new(InMemoryCheckpoint::new());
        let orchestrator =
            SagaOrchestrator::new(vec![write_step("prepare"), final_step("publish")]).unwrap();

        let outcome = orchestrator
            .execute_with_checkpoint(
                |_tool, _args| async { Ok("done".to_string()) },
                checkpoint.clone(),
            )
            .await;
        assert!(outcome.success);

        // Only the Final step is checkpointed...
        {
            let intents = checkpoint.intents.lock().await;
            assert_eq!(intents.len(), 1);
            assert_eq!(intents[0].0.name, "publish");
        }

        // ...and its successful outcome was recorded.
        {
            let outcomes = checkpoint.outcomes.lock().await;
            let recorded: Vec<bool> = outcomes.values().copied().collect();
            assert_eq!(recorded, vec![true]);
        }

        // With the outcome in place, nothing is left pending.
        assert!(checkpoint.pending_intents().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn intent_without_outcome_is_pending() {
        let checkpoint = InMemoryCheckpoint::new();

        // An intent with no outcome is what a crash mid-step leaves behind.
        let step = final_step("deploy");
        let key = IdempotencyKey::generate();
        checkpoint.record_intent(&step, &key).await.unwrap();

        let pending = checkpoint.pending_intents().await.unwrap();
        assert_eq!(pending.len(), 1);
        let (pending_step, pending_key) = &pending[0];
        assert_eq!(pending_step.name, "deploy");
        assert_eq!(*pending_key, key);
    }

    #[tokio::test]
    async fn recovery_re_executes_pending_final_step() {
        let checkpoint = InMemoryCheckpoint::new();

        let step = final_step("deploy");
        let key = IdempotencyKey::generate();
        checkpoint.record_intent(&step, &key).await.unwrap();

        let results = SagaOrchestrator::recover(&checkpoint, |_tool, _args| async {
            Ok("recovered".to_string())
        })
        .await
        .unwrap();

        assert_eq!(results.len(), 1);
        let recovered = &results[0];
        assert!(recovered.success);
        assert_eq!(recovered.name, "recover:deploy");
        assert_eq!(recovered.output, "recovered");

        // Recovery records the outcome, so the intent is resolved.
        assert!(checkpoint.pending_intents().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn completed_intent_is_not_re_executed() {
        let checkpoint = InMemoryCheckpoint::new();

        let step = final_step("deploy");
        let key = IdempotencyKey::generate();
        checkpoint.record_intent(&step, &key).await.unwrap();
        checkpoint.record_outcome(&key, true).await.unwrap();

        let results = SagaOrchestrator::recover(&checkpoint, |_tool, _args| async {
            Ok("should not run".to_string())
        })
        .await
        .unwrap();

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn execute_without_checkpoint_still_runs_final_steps() {
        let orchestrator =
            SagaOrchestrator::new(vec![write_step("create"), final_step("publish")]).unwrap();

        let outcome = orchestrator
            .execute(|_tool, _args| async { Ok("ok".to_string()) })
            .await;

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 2);
    }
}