1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub enum StepClassification {
15 ReadOnly,
17 Compensatable,
19 Final,
21}
22
23#[derive(Debug, Clone)]
25pub struct SagaStep {
26 pub name: String,
28 pub classification: StepClassification,
30 pub action: SagaAction,
32 pub compensation: Option<SagaAction>,
34}
35
36#[derive(Debug, Clone)]
38pub struct SagaAction {
39 pub tool_name: String,
41 pub arguments: serde_json::Value,
43}
44
/// Outcome of running one action (a step, a compensation or a recovery).
#[derive(Clone, Debug)]
pub struct StepResult {
    /// Step name; compensation and recovery results are prefixed with
    /// `compensate:` and `recover:` respectively.
    pub name: String,
    /// `true` when the executor returned `Ok`.
    pub success: bool,
    /// The executor's `Ok` payload, or an empty string on failure.
    pub output: String,
    /// The executor's `Err` payload, or `None` on success.
    pub error: Option<String>,
}
57
58#[derive(Debug, Clone)]
60pub struct SagaResult {
61 pub success: bool,
63 pub step_results: Vec<StepResult>,
65 pub compensation_results: Vec<StepResult>,
67 pub summary: String,
69}
70
/// Opaque identifier that ties a recorded intent to its recorded outcome.
///
/// The wrapped string is public so checkpoint implementations can use it as a
/// storage key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct IdempotencyKey(pub String);
77
78impl IdempotencyKey {
79 pub fn generate() -> Self {
81 Self(uuid::Uuid::new_v4().to_string())
82 }
83}
84
85#[async_trait::async_trait]
90pub trait SagaCheckpoint: Send + Sync {
91 async fn record_intent(&self, step: &SagaStep, key: &IdempotencyKey) -> Result<(), SagaError>;
93
94 async fn record_outcome(&self, key: &IdempotencyKey, success: bool) -> Result<(), SagaError>;
96
97 async fn pending_intents(&self) -> Result<Vec<(SagaStep, IdempotencyKey)>, SagaError>;
99}
100
101pub struct InMemoryCheckpoint {
106 intents: Mutex<Vec<(SagaStep, IdempotencyKey)>>,
107 outcomes: Mutex<HashMap<String, bool>>,
108}
109
110impl Default for InMemoryCheckpoint {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl InMemoryCheckpoint {
117 pub fn new() -> Self {
118 Self {
119 intents: Mutex::new(Vec::new()),
120 outcomes: Mutex::new(HashMap::new()),
121 }
122 }
123}
124
125#[async_trait::async_trait]
126impl SagaCheckpoint for InMemoryCheckpoint {
127 async fn record_intent(&self, step: &SagaStep, key: &IdempotencyKey) -> Result<(), SagaError> {
128 self.intents.lock().await.push((step.clone(), key.clone()));
129 Ok(())
130 }
131
132 async fn record_outcome(&self, key: &IdempotencyKey, success: bool) -> Result<(), SagaError> {
133 self.outcomes.lock().await.insert(key.0.clone(), success);
134 Ok(())
135 }
136
137 async fn pending_intents(&self) -> Result<Vec<(SagaStep, IdempotencyKey)>, SagaError> {
138 let intents = self.intents.lock().await;
139 let outcomes = self.outcomes.lock().await;
140 Ok(intents
141 .iter()
142 .filter(|(_, key)| !outcomes.contains_key(&key.0))
143 .cloned()
144 .collect())
145 }
146}
147
148pub struct SagaOrchestrator {
150 steps: Vec<SagaStep>,
151}
152
153impl SagaOrchestrator {
154 pub fn new(steps: Vec<SagaStep>) -> Result<Self, SagaError> {
158 let mut seen_final = false;
160 for step in &steps {
161 if seen_final && step.classification != StepClassification::Final {
162 return Err(SagaError::InvalidStepOrder {
163 step: step.name.clone(),
164 reason: "Non-final steps cannot appear after a Final step".into(),
165 });
166 }
167 if step.classification == StepClassification::Final {
168 seen_final = true;
169 }
170 }
171
172 for step in &steps {
174 if step.classification == StepClassification::Compensatable
175 && step.compensation.is_none()
176 {
177 return Err(SagaError::MissingCompensation {
178 step: step.name.clone(),
179 });
180 }
181 }
182
183 Ok(Self { steps })
184 }
185
186 pub async fn execute<F, Fut>(&self, executor: F) -> SagaResult
190 where
191 F: Fn(String, serde_json::Value) -> Fut,
192 Fut: std::future::Future<Output = Result<String, String>>,
193 {
194 self.execute_inner(&executor, None).await
195 }
196
197 pub async fn execute_with_checkpoint<F, Fut>(
203 &self,
204 executor: F,
205 checkpoint: Arc<dyn SagaCheckpoint>,
206 ) -> SagaResult
207 where
208 F: Fn(String, serde_json::Value) -> Fut,
209 Fut: std::future::Future<Output = Result<String, String>>,
210 {
211 self.execute_inner(&executor, Some(checkpoint)).await
212 }
213
214 async fn execute_inner<F, Fut>(
215 &self,
216 executor: &F,
217 checkpoint: Option<Arc<dyn SagaCheckpoint>>,
218 ) -> SagaResult
219 where
220 F: Fn(String, serde_json::Value) -> Fut,
221 Fut: std::future::Future<Output = Result<String, String>>,
222 {
223 let mut step_results = Vec::new();
224 let mut completed_compensatable: Vec<&SagaStep> = Vec::new();
225
226 for step in &self.steps {
227 let idempotency_key = if step.classification == StepClassification::Final {
229 if let Some(cp) = &checkpoint {
230 let key = IdempotencyKey::generate();
231 if let Err(e) = cp.record_intent(step, &key).await {
232 return SagaResult {
233 success: false,
234 step_results,
235 compensation_results: self
236 .compensate(&completed_compensatable, executor)
237 .await,
238 summary: format!(
239 "Failed to record intent for Final step '{}': {}",
240 step.name, e
241 ),
242 };
243 }
244 Some(key)
245 } else {
246 None
247 }
248 } else {
249 None
250 };
251
252 let result =
253 executor(step.action.tool_name.clone(), step.action.arguments.clone()).await;
254
255 match result {
256 Ok(output) => {
257 if let Some(key) = &idempotency_key {
259 if let Some(cp) = &checkpoint {
260 let _ = cp.record_outcome(key, true).await;
261 }
262 }
263
264 step_results.push(StepResult {
265 name: step.name.clone(),
266 success: true,
267 output,
268 error: None,
269 });
270
271 if step.classification == StepClassification::Compensatable {
272 completed_compensatable.push(step);
273 }
274 }
275 Err(error) => {
276 if let Some(key) = &idempotency_key {
278 if let Some(cp) = &checkpoint {
279 let _ = cp.record_outcome(key, false).await;
280 }
281 }
282
283 step_results.push(StepResult {
284 name: step.name.clone(),
285 success: false,
286 output: String::new(),
287 error: Some(error.clone()),
288 });
289
290 let compensation_results =
292 self.compensate(&completed_compensatable, executor).await;
293
294 return SagaResult {
295 success: false,
296 step_results,
297 compensation_results,
298 summary: format!("Saga failed at step '{}': {}", step.name, error),
299 };
300 }
301 }
302 }
303
304 SagaResult {
305 success: true,
306 step_results,
307 compensation_results: Vec::new(),
308 summary: "Saga completed successfully".into(),
309 }
310 }
311
312 pub async fn recover<F, Fut>(
318 checkpoint: &dyn SagaCheckpoint,
319 executor: F,
320 ) -> Result<Vec<StepResult>, SagaError>
321 where
322 F: Fn(String, serde_json::Value) -> Fut,
323 Fut: std::future::Future<Output = Result<String, String>>,
324 {
325 let pending = checkpoint.pending_intents().await?;
326 let mut results = Vec::new();
327
328 for (step, key) in &pending {
329 let result =
330 executor(step.action.tool_name.clone(), step.action.arguments.clone()).await;
331
332 let success = result.is_ok();
333 let _ = checkpoint.record_outcome(key, success).await;
334
335 results.push(StepResult {
336 name: format!("recover:{}", step.name),
337 success,
338 output: result.as_ref().cloned().unwrap_or_default(),
339 error: result.err(),
340 });
341 }
342
343 Ok(results)
344 }
345
346 async fn compensate<F, Fut>(&self, completed: &[&SagaStep], executor: &F) -> Vec<StepResult>
347 where
348 F: Fn(String, serde_json::Value) -> Fut,
349 Fut: std::future::Future<Output = Result<String, String>>,
350 {
351 let mut results = Vec::new();
352
353 for step in completed.iter().rev() {
355 if let Some(compensation) = &step.compensation {
356 let result = executor(
357 compensation.tool_name.clone(),
358 compensation.arguments.clone(),
359 )
360 .await;
361
362 results.push(StepResult {
363 name: format!("compensate:{}", step.name),
364 success: result.is_ok(),
365 output: result.as_ref().cloned().unwrap_or_default(),
366 error: result.err(),
367 });
368 }
369 }
370
371 results
372 }
373
374 pub fn step_count(&self) -> usize {
376 self.steps.len()
377 }
378
379 pub fn steps_by_classification(&self) -> HashMap<String, usize> {
381 let mut counts = HashMap::new();
382 for step in &self.steps {
383 let key = format!("{:?}", step.classification);
384 *counts.entry(key).or_insert(0) += 1;
385 }
386 counts
387 }
388}
389
390#[derive(Debug, thiserror::Error)]
392pub enum SagaError {
393 #[error("Invalid step order for '{step}': {reason}")]
394 InvalidStepOrder { step: String, reason: String },
395
396 #[error("Compensatable step '{step}' is missing a compensation action")]
397 MissingCompensation { step: String },
398
399 #[error("Checkpoint operation failed: {0}")]
400 CheckpointFailed(String),
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Shared builder behind the three fixture helpers below.
    fn make_step(
        name: &str,
        classification: StepClassification,
        tool: &str,
        arguments: serde_json::Value,
        compensation: Option<SagaAction>,
    ) -> SagaStep {
        SagaStep {
            name: name.to_string(),
            classification,
            action: SagaAction {
                tool_name: tool.to_string(),
                arguments,
            },
            compensation,
        }
    }

    /// A `ReadOnly` step that invokes the `read` tool.
    fn read_step(name: &str) -> SagaStep {
        make_step(
            name,
            StepClassification::ReadOnly,
            "read",
            serde_json::json!({"path": name}),
            None,
        )
    }

    /// A `Compensatable` step that invokes `write` and is undone by `delete`.
    fn write_step(name: &str) -> SagaStep {
        let undo = SagaAction {
            tool_name: "delete".to_string(),
            arguments: serde_json::json!({"path": name}),
        };
        make_step(
            name,
            StepClassification::Compensatable,
            "write",
            serde_json::json!({"path": name, "data": "content"}),
            Some(undo),
        )
    }

    /// A `Final` step that invokes the `publish` tool.
    fn final_step(name: &str) -> SagaStep {
        make_step(
            name,
            StepClassification::Final,
            "publish",
            serde_json::json!({"target": name}),
            None,
        )
    }

    #[test]
    fn test_valid_saga_creation() {
        let steps = vec![
            read_step("check"),
            write_step("create"),
            write_step("update"),
            final_step("publish"),
        ];

        let saga = SagaOrchestrator::new(steps).expect("a well-ordered saga must be accepted");
        assert_eq!(saga.step_count(), 4);
    }

    #[test]
    fn test_invalid_order_non_final_after_final() {
        // A read step placed after the Final step must be rejected.
        let steps = vec![
            write_step("create"),
            final_step("publish"),
            read_step("check"),
        ];

        assert!(SagaOrchestrator::new(steps).is_err());
    }

    #[test]
    fn test_missing_compensation() {
        // Compensatable, but deliberately built without a compensation.
        let bad = make_step(
            "bad",
            StepClassification::Compensatable,
            "write",
            serde_json::json!({}),
            None,
        );

        assert!(SagaOrchestrator::new(vec![bad]).is_err());
    }

    #[tokio::test]
    async fn test_saga_all_succeed() {
        let steps = vec![
            read_step("check"),
            write_step("create"),
            write_step("update"),
        ];
        let saga = SagaOrchestrator::new(steps).unwrap();

        let outcome = saga
            .execute(|_tool, _args| async { Ok("success".to_string()) })
            .await;

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 3);
        assert!(outcome.compensation_results.is_empty());
    }

    #[tokio::test]
    async fn test_saga_fail_at_step_3_compensates_2_and_1() {
        let steps = vec![
            read_step("check"),
            write_step("create"),
            write_step("update"),
            write_step("finalize"),
            final_step("publish"),
        ];
        let saga = SagaOrchestrator::new(steps).unwrap();

        // The fourth executor call (index 3, the "finalize" step) fails.
        let calls = AtomicU32::new(0);
        let outcome = saga
            .execute(|_tool, _args| {
                let index = calls.fetch_add(1, Ordering::Relaxed);
                async move {
                    if index == 3 {
                        Err("Connection refused".to_string())
                    } else {
                        Ok("ok".to_string())
                    }
                }
            })
            .await;

        assert!(!outcome.success);
        assert_eq!(outcome.step_results.len(), 4);
        assert!(!outcome.step_results[3].success);

        // Completed compensatable steps are undone newest-first.
        assert_eq!(outcome.compensation_results.len(), 2);
        assert_eq!(outcome.compensation_results[0].name, "compensate:update");
        assert_eq!(outcome.compensation_results[1].name, "compensate:create");
    }

    #[tokio::test]
    async fn test_saga_fail_at_readonly_no_compensation() {
        let saga = SagaOrchestrator::new(vec![read_step("check")]).unwrap();

        let outcome = saga
            .execute(|_tool, _args| async { Err("fail".to_string()) })
            .await;

        assert!(!outcome.success);
        assert!(outcome.compensation_results.is_empty());
    }

    #[test]
    fn test_steps_by_classification() {
        let steps = vec![
            read_step("r1"),
            write_step("w1"),
            write_step("w2"),
            final_step("f1"),
        ];
        let saga = SagaOrchestrator::new(steps).unwrap();

        let counts = saga.steps_by_classification();
        assert_eq!(counts.get("ReadOnly"), Some(&1));
        assert_eq!(counts.get("Compensatable"), Some(&2));
        assert_eq!(counts.get("Final"), Some(&1));
    }

    #[tokio::test]
    async fn test_final_step_with_checkpoint_records_intent_and_outcome() {
        let checkpoint = Arc::new(InMemoryCheckpoint::new());
        let steps = vec![write_step("prepare"), final_step("publish")];
        let saga = SagaOrchestrator::new(steps).unwrap();

        let outcome = saga
            .execute_with_checkpoint(
                |_tool, _args| async { Ok("done".to_string()) },
                checkpoint.clone(),
            )
            .await;
        assert!(outcome.success);

        // Each guard lives in its own scope so it is released before
        // `pending_intents` takes the same locks.
        {
            let intents = checkpoint.intents.lock().await;
            assert_eq!(intents.len(), 1);
            assert_eq!(intents[0].0.name, "publish");
        }
        {
            let outcomes = checkpoint.outcomes.lock().await;
            assert_eq!(outcomes.len(), 1);
            assert!(outcomes.values().all(|recorded| *recorded));
        }

        let pending = checkpoint.pending_intents().await.unwrap();
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn test_simulated_crash_leaves_pending_intent() {
        let checkpoint = Arc::new(InMemoryCheckpoint::new());

        // An intent without an outcome models a crash mid-step.
        let key = IdempotencyKey::generate();
        checkpoint
            .record_intent(&final_step("deploy"), &key)
            .await
            .unwrap();

        let pending = checkpoint.pending_intents().await.unwrap();
        assert_eq!(pending.len(), 1);
        let (step, pending_key) = &pending[0];
        assert_eq!(step.name, "deploy");
        assert_eq!(*pending_key, key);
    }

    #[tokio::test]
    async fn test_recovery_re_executes_pending_final_step() {
        let checkpoint = Arc::new(InMemoryCheckpoint::new());

        let key = IdempotencyKey::generate();
        checkpoint
            .record_intent(&final_step("deploy"), &key)
            .await
            .unwrap();

        let recovered = SagaOrchestrator::recover(checkpoint.as_ref(), |_tool, _args| async {
            Ok("recovered".to_string())
        })
        .await
        .unwrap();

        assert_eq!(recovered.len(), 1);
        let first = &recovered[0];
        assert!(first.success);
        assert_eq!(first.name, "recover:deploy");
        assert_eq!(first.output, "recovered");

        // The recorded outcome clears the pending intent.
        let pending = checkpoint.pending_intents().await.unwrap();
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn test_idempotency_key_prevents_double_execution() {
        let checkpoint = Arc::new(InMemoryCheckpoint::new());

        // Intent and outcome are both present: nothing is pending.
        let key = IdempotencyKey::generate();
        checkpoint
            .record_intent(&final_step("deploy"), &key)
            .await
            .unwrap();
        checkpoint.record_outcome(&key, true).await.unwrap();

        let recovered = SagaOrchestrator::recover(checkpoint.as_ref(), |_tool, _args| async {
            Ok("should not run".to_string())
        })
        .await
        .unwrap();

        assert!(recovered.is_empty());
    }

    #[tokio::test]
    async fn test_execute_without_checkpoint_backward_compat() {
        let steps = vec![write_step("create"), final_step("publish")];
        let saga = SagaOrchestrator::new(steps).unwrap();

        let outcome = saga
            .execute(|_tool, _args| async { Ok("ok".to_string()) })
            .await;

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 2);
    }
}