rust_logic_graph/saga/
mod.rs1use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9#[derive(Debug, Serialize, Deserialize, Clone)]
10pub enum SagaStepStatus {
11 Pending,
12 Completed,
13 Failed,
14 Compensated,
15 Aborted,
16}
17
18pub struct SagaStep {
19 pub id: String,
20 pub action: Box<dyn Fn(&mut SagaContext) -> Result<()>>,
21 pub compensation: Option<Box<dyn Fn(&mut SagaContext) -> Result<()>>>,
22 pub status: SagaStepStatus,
23 pub timeout: Option<Duration>,
24}
25
26#[derive(Debug, Serialize, Deserialize, Default, Clone)]
27pub struct SagaContext {
28 pub data: HashMap<String, serde_json::Value>,
29}
30
31#[derive(Debug)]
32pub struct SagaState {
33 pub steps: Vec<SagaStepStatus>,
34 pub started_at: Instant,
35 pub finished_at: Option<Instant>,
36 pub aborted: bool,
37}
38
39pub struct SagaCoordinator {
40 pub steps: Vec<SagaStep>,
41 pub context: SagaContext,
42 pub state: SagaState,
43 pub deadline: Option<Instant>,
44}
45
46impl SagaCoordinator {
47 pub fn new(deadline: Option<Duration>) -> Self {
48 let now = Instant::now();
49 Self {
50 steps: Vec::new(),
51 context: SagaContext::default(),
52 state: SagaState {
53 steps: Vec::new(),
54 started_at: now,
55 finished_at: None,
56 aborted: false,
57 },
58 deadline: deadline.map(|d| now + d),
59 }
60 }
61
62 pub fn add_step(&mut self, step: SagaStep) {
63 self.steps.push(step);
64 self.state.steps.push(SagaStepStatus::Pending);
65 }
66
67 pub fn execute(&mut self) -> Result<()> {
68 for (i, step) in self.steps.iter_mut().enumerate() {
69 if let Some(deadline) = self.deadline {
70 if Instant::now() > deadline {
71 self.state.aborted = true;
72 self.state.finished_at = Some(Instant::now());
73 self.compensate(i)?;
74 return Err(anyhow::anyhow!("Saga deadline exceeded"));
75 }
76 }
77 let step_start = Instant::now();
78 let timeout = step.timeout;
79 let result = (step.action)(&mut self.context);
80 if let Err(e) = result {
81 self.state.steps[i] = SagaStepStatus::Failed;
82 self.compensate(i)?;
83 self.state.aborted = true;
84 self.state.finished_at = Some(Instant::now());
85 return Err(e);
86 }
87 if let Some(t) = timeout {
88 if step_start.elapsed() > t {
89 self.state.steps[i] = SagaStepStatus::Aborted;
90 self.compensate(i)?;
91 self.state.aborted = true;
92 self.state.finished_at = Some(Instant::now());
93 return Err(anyhow::anyhow!("Step timeout exceeded"));
94 }
95 }
96 self.state.steps[i] = SagaStepStatus::Completed;
97 }
98 self.state.finished_at = Some(Instant::now());
99 Ok(())
100 }
101
102 fn compensate(&mut self, failed_index: usize) -> Result<()> {
103 for (i, step) in self.steps[..=failed_index].iter_mut().rev().enumerate() {
104 if let Some(comp) = &step.compensation {
105 let result = (comp)(&mut self.context);
106 if result.is_ok() {
107 self.state.steps[failed_index - i] = SagaStepStatus::Compensated;
108 }
109 }
110 }
111 Ok(())
112 }
113}