rust_rabbit/patterns/saga.rs

1use anyhow::Result;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6use tracing::{debug, error, info, warn};
7use uuid::Uuid;
8
9use crate::error::RustRabbitError;
10
11/// Unique identifier for saga instances
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub struct SagaId(String);
14
15impl SagaId {
16    pub fn new() -> Self {
17        Self(Uuid::new_v4().to_string())
18    }
19
20    pub fn from_string(id: String) -> Self {
21        Self(id)
22    }
23
24    pub fn as_str(&self) -> &str {
25        &self.0
26    }
27}
28
29impl Default for SagaId {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl std::fmt::Display for SagaId {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(f, "{}", self.0)
38    }
39}
40
/// Saga execution status.
///
/// The coordinator moves a saga from `Running` to either `Completed` or `Compensating`, and
/// from `Compensating` to either `Compensated` or `CompensationFailed`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SagaStatus {
    /// Saga is currently executing steps (the status `SagaInstance::new` starts with).
    Running,
    /// Saga completed successfully: no pending steps remained.
    Completed,
    /// A step failed (or exhausted its retries) and compensation is in progress.
    Compensating,
    /// Saga was fully compensated (rolled back).
    Compensated,
    /// A compensation action did not succeed; the rollback stopped at that step.
    CompensationFailed,
}
55
/// Individual step in a saga: a forward action, its optional compensation, and the
/// bookkeeping the coordinator updates while executing it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SagaStep {
    /// Identifier used to look the step up within its saga.
    // NOTE(review): lookups take the first step with a matching id, so ids are presumably
    // meant to be unique within one saga — confirm.
    pub step_id: String,
    /// Forward action executed for this step.
    pub action: SagaAction,
    /// Action that undoes `action`; completed steps without one are skipped during compensation.
    pub compensation: Option<SagaAction>,
    /// Current lifecycle status of the step.
    pub status: StepStatus,
    /// When execution of `action` last started (re-set on every retry attempt).
    pub executed_at: Option<DateTime<Utc>>,
    /// Intended to hold the time compensation of this step began.
    // NOTE(review): verify the coordinator actually persists this on the stored step.
    pub compensated_at: Option<DateTime<Utc>>,
    /// Number of `StepResult::Retry` results received so far.
    pub retry_count: u32,
    /// Retry limit: the `Retry` result that brings `retry_count` up to this value fails the
    /// step instead of re-queuing it.
    pub max_retries: u32,
}
68
/// Status of an individual saga step.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StepStatus {
    /// Not yet executed, or re-queued after a `StepResult::Retry` below `max_retries`.
    Pending,
    /// The step's action is currently being executed.
    Running,
    /// The action succeeded; the step is eligible for compensation if the saga later fails.
    Completed,
    /// The action reported failure, returned an error, or exhausted its retries.
    Failed,
    /// Compensation of this step has started but not finished.
    Compensating,
    /// The step's compensation action succeeded.
    Compensated,
    /// The step's compensation action did not succeed.
    CompensationFailed,
}
80
/// Action to be executed in a saga step (either a forward action or a compensation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SagaAction {
    /// Selects the `SagaStepExecutor` registered under this name.
    pub action_type: String,
    /// Opaque bytes handed to the executor unchanged.
    pub payload: Vec<u8>,
    /// Time budget for the action (30 seconds when built with `SagaAction::new`).
    // NOTE(review): the coordinator in this file does not enforce it; presumably executors must.
    pub timeout: std::time::Duration,
    /// Optional idempotency key, passed through to the executor; this file does not interpret it.
    pub idempotency_key: Option<String>,
}
89
90impl SagaAction {
91    pub fn new(action_type: String, payload: Vec<u8>) -> Self {
92        Self {
93            action_type,
94            payload,
95            timeout: std::time::Duration::from_secs(30),
96            idempotency_key: None,
97        }
98    }
99
100    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
101        self.timeout = timeout;
102        self
103    }
104
105    pub fn with_idempotency_key(mut self, key: String) -> Self {
106        self.idempotency_key = Some(key);
107        self
108    }
109}
110
/// Saga instance containing all steps and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SagaInstance {
    /// Identifier; also the key under which the coordinator stores this saga.
    pub saga_id: SagaId,
    /// Free-form saga type name (only logged within this file).
    pub saga_type: String,
    /// Overall execution status.
    pub status: SagaStatus,
    /// Steps in execution order; compensation walks completed steps in reverse.
    pub steps: Vec<SagaStep>,
    /// Shared key/value context: a snapshot is passed to every executor, and the pairs
    /// returned by successful steps are merged into it.
    pub context: HashMap<String, String>,
    /// When the instance was created.
    pub created_at: DateTime<Utc>,
    /// Timestamp of the most recently recorded update (not bumped on every mutation).
    pub updated_at: DateTime<Utc>,
    /// When the saga was marked `Completed` or `Compensated`.
    pub completed_at: Option<DateTime<Utc>>,
}
123
124impl SagaInstance {
125    pub fn new(saga_type: String, steps: Vec<SagaStep>) -> Self {
126        let now = Utc::now();
127        Self {
128            saga_id: SagaId::new(),
129            saga_type,
130            status: SagaStatus::Running,
131            steps,
132            context: HashMap::new(),
133            created_at: now,
134            updated_at: now,
135            completed_at: None,
136        }
137    }
138
139    pub fn get_current_step(&self) -> Option<&SagaStep> {
140        self.steps
141            .iter()
142            .find(|step| step.status == StepStatus::Pending)
143    }
144
145    pub fn get_current_step_mut(&mut self) -> Option<&mut SagaStep> {
146        self.steps
147            .iter_mut()
148            .find(|step| step.status == StepStatus::Pending)
149    }
150
151    pub fn get_failed_steps(&self) -> Vec<&SagaStep> {
152        self.steps
153            .iter()
154            .filter(|step| step.status == StepStatus::Failed)
155            .collect()
156    }
157
158    pub fn add_context(&mut self, key: String, value: String) {
159        self.context.insert(key, value);
160        self.updated_at = Utc::now();
161    }
162
163    pub fn mark_completed(&mut self) {
164        self.status = SagaStatus::Completed;
165        self.completed_at = Some(Utc::now());
166        self.updated_at = Utc::now();
167    }
168
169    pub fn mark_compensating(&mut self) {
170        self.status = SagaStatus::Compensating;
171        self.updated_at = Utc::now();
172    }
173
174    pub fn mark_compensated(&mut self) {
175        self.status = SagaStatus::Compensated;
176        self.completed_at = Some(Utc::now());
177        self.updated_at = Utc::now();
178    }
179}
180
/// Outcome reported by a saga step executor for one action or compensation.
#[derive(Debug)]
pub enum StepResult {
    /// Ask the coordinator to run the step again. For forward actions, the step is re-queued
    /// until its `max_retries` is reached. During compensation it counts as a failure.
    Retry,
    /// The action failed; the string describes why.
    Failure(String),
    /// The action succeeded; for forward actions the pairs are merged into the saga context.
    Success(HashMap<String, String>),
}
188
/// Trait for implementing saga step executors.
///
/// An executor is registered with `SagaCoordinator::register_executor` under an action type.
/// It receives every action, forward or compensation, whose `action_type` matches.
#[async_trait::async_trait]
pub trait SagaStepExecutor {
    /// Execute a step's forward action, given a snapshot of the saga context.
    ///
    /// Outcomes:
    /// - `Success`: the returned pairs are merged into the saga context.
    /// - `Retry`: the step is re-queued until its `max_retries` is reached.
    /// - `Failure` or `Err`: the step is marked failed and the saga starts compensating
    ///   the steps that already completed.
    async fn execute_step(
        &self,
        action: &SagaAction,
        context: &HashMap<String, String>,
    ) -> Result<StepResult>;
    /// Run a compensation action that undoes a previously completed step.
    ///
    /// Only `Success` counts as compensated. `Failure`, `Retry`, or an `Err` marks the saga
    /// `CompensationFailed` and stops the rollback.
    async fn compensate_step(
        &self,
        action: &SagaAction,
        context: &HashMap<String, String>,
    ) -> Result<StepResult>;
}
203
204/// Saga coordinator responsible for managing saga execution
205#[derive(Clone)]
206pub struct SagaCoordinator {
207    active_sagas: Arc<Mutex<HashMap<SagaId, SagaInstance>>>,
208    step_executors: HashMap<String, Arc<dyn SagaStepExecutor + Send + Sync>>,
209}
210
211impl Default for SagaCoordinator {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl SagaCoordinator {
218    pub fn new() -> Self {
219        Self {
220            active_sagas: Arc::new(Mutex::new(HashMap::new())),
221            step_executors: HashMap::new(),
222        }
223    }
224
225    /// Register a step executor for a specific action type
226    pub fn register_executor(
227        &mut self,
228        action_type: String,
229        executor: Arc<dyn SagaStepExecutor + Send + Sync>,
230    ) {
231        self.step_executors.insert(action_type, executor);
232    }
233
234    /// Start a new saga
235    pub async fn start_saga(&self, saga: SagaInstance) -> Result<()> {
236        let saga_id = saga.saga_id.clone();
237
238        info!(
239            saga_id = %saga_id,
240            saga_type = %saga.saga_type,
241            steps_count = saga.steps.len(),
242            "Starting new saga"
243        );
244
245        // Store the saga
246        {
247            let mut active_sagas = self.active_sagas.lock().unwrap();
248            active_sagas.insert(saga_id.clone(), saga.clone());
249        }
250
251        // Begin execution
252        self.execute_next_step(saga_id).await
253    }
254
255    /// Execute the next pending step in a saga
256    async fn execute_next_step(&self, saga_id: SagaId) -> Result<()> {
257        let (step_id, action, context) = {
258            let mut active_sagas = self.active_sagas.lock().unwrap();
259            let saga = active_sagas
260                .get_mut(&saga_id)
261                .ok_or_else(|| RustRabbitError::SagaNotFound)?;
262
263            if let Some(step) = saga.get_current_step_mut() {
264                step.status = StepStatus::Running;
265                step.executed_at = Some(Utc::now());
266                (
267                    step.step_id.clone(),
268                    step.action.clone(),
269                    saga.context.clone(),
270                )
271            } else {
272                // No more steps - mark saga as completed
273                saga.mark_completed();
274                info!(saga_id = %saga_id, "Saga completed successfully");
275                return Ok(());
276            }
277        };
278
279        debug!(
280            saga_id = %saga_id,
281            step_id = %step_id,
282            action_type = %action.action_type,
283            "Executing saga step"
284        );
285
286        // Execute the step
287        let result = self.execute_step(&action, &context).await;
288
289        // Update saga based on result
290        {
291            let mut active_sagas = self.active_sagas.lock().unwrap();
292            let saga = active_sagas
293                .get_mut(&saga_id)
294                .ok_or_else(|| RustRabbitError::SagaNotFound)?;
295
296            if let Some(step) = saga.steps.iter_mut().find(|s| s.step_id == step_id) {
297                match result {
298                    Ok(StepResult::Success(step_context)) => {
299                        step.status = StepStatus::Completed;
300                        saga.context.extend(step_context);
301                        saga.updated_at = Utc::now();
302
303                        info!(
304                            saga_id = %saga_id,
305                            step_id = %step_id,
306                            "Step completed successfully"
307                        );
308                    }
309                    Ok(StepResult::Failure(error)) => {
310                        step.status = StepStatus::Failed;
311                        saga.status = SagaStatus::Compensating;
312                        saga.updated_at = Utc::now();
313
314                        error!(
315                            saga_id = %saga_id,
316                            step_id = %step_id,
317                            error = %error,
318                            "Step failed, starting compensation"
319                        );
320                    }
321                    Ok(StepResult::Retry) => {
322                        step.retry_count += 1;
323                        if step.retry_count >= step.max_retries {
324                            step.status = StepStatus::Failed;
325                            saga.status = SagaStatus::Compensating;
326
327                            error!(
328                                saga_id = %saga_id,
329                                step_id = %step_id,
330                                retry_count = step.retry_count,
331                                "Step exceeded max retries, starting compensation"
332                            );
333                        } else {
334                            step.status = StepStatus::Pending;
335
336                            warn!(
337                                saga_id = %saga_id,
338                                step_id = %step_id,
339                                retry_count = step.retry_count,
340                                "Step will be retried"
341                            );
342                        }
343                        saga.updated_at = Utc::now();
344                    }
345                    Err(error) => {
346                        step.status = StepStatus::Failed;
347                        saga.status = SagaStatus::Compensating;
348                        saga.updated_at = Utc::now();
349
350                        error!(
351                            saga_id = %saga_id,
352                            step_id = %step_id,
353                            error = %error,
354                            "Step execution error, starting compensation"
355                        );
356                    }
357                }
358            }
359        }
360
361        // Continue execution or start compensation
362        let saga_status = {
363            let active_sagas = self.active_sagas.lock().unwrap();
364            active_sagas
365                .get(&saga_id)
366                .map(|s| s.status.clone())
367                .unwrap_or(SagaStatus::Completed)
368        };
369
370        match saga_status {
371            SagaStatus::Running => {
372                // Continue with next step - for now, just return Ok to avoid recursion issues
373                // In production, this could be handled with a message queue or event loop
374                debug!(saga_id = %saga_id, "Saga step completed, next step will be processed");
375                Ok(())
376            }
377            SagaStatus::Compensating => {
378                // Start compensation
379                self.compensate_saga(saga_id).await
380            }
381            _ => Ok(()),
382        }
383    }
384
385    /// Execute a single step
386    async fn execute_step(
387        &self,
388        action: &SagaAction,
389        context: &HashMap<String, String>,
390    ) -> Result<StepResult> {
391        if let Some(executor) = self.step_executors.get(&action.action_type) {
392            executor.execute_step(action, context).await
393        } else {
394            Err(RustRabbitError::SagaExecutorNotFound(action.action_type.clone()).into())
395        }
396    }
397
398    /// Compensate (rollback) a failed saga
399    async fn compensate_saga(&self, saga_id: SagaId) -> Result<()> {
400        info!(saga_id = %saga_id, "Starting saga compensation");
401
402        let completed_steps: Vec<SagaStep> = {
403            let active_sagas = self.active_sagas.lock().unwrap();
404            let saga = active_sagas
405                .get(&saga_id)
406                .ok_or_else(|| RustRabbitError::SagaNotFound)?;
407
408            saga.steps
409                .iter()
410                .filter(|step| step.status == StepStatus::Completed)
411                .cloned()
412                .collect()
413        };
414
415        // Compensate in reverse order
416        for mut step in completed_steps.into_iter().rev() {
417            if let Some(compensation) = &step.compensation {
418                debug!(
419                    saga_id = %saga_id,
420                    step_id = %step.step_id,
421                    "Compensating step"
422                );
423
424                step.status = StepStatus::Compensating;
425                step.compensated_at = Some(Utc::now());
426
427                let context = {
428                    let active_sagas = self.active_sagas.lock().unwrap();
429                    active_sagas
430                        .get(&saga_id)
431                        .map(|s| s.context.clone())
432                        .unwrap_or_default()
433                };
434
435                let result = self.compensate_step(compensation, &context).await;
436
437                // Update step status
438                {
439                    let mut active_sagas = self.active_sagas.lock().unwrap();
440                    if let Some(saga) = active_sagas.get_mut(&saga_id) {
441                        if let Some(saga_step) =
442                            saga.steps.iter_mut().find(|s| s.step_id == step.step_id)
443                        {
444                            match result {
445                                Ok(StepResult::Success(_)) => {
446                                    saga_step.status = StepStatus::Compensated;
447                                    info!(
448                                        saga_id = %saga_id,
449                                        step_id = %step.step_id,
450                                        "Step compensated successfully"
451                                    );
452                                }
453                                _ => {
454                                    saga_step.status = StepStatus::CompensationFailed;
455                                    saga.status = SagaStatus::CompensationFailed;
456                                    error!(
457                                        saga_id = %saga_id,
458                                        step_id = %step.step_id,
459                                        "Step compensation failed"
460                                    );
461                                    return Err(RustRabbitError::SagaCompensationFailed.into());
462                                }
463                            }
464                        }
465                    }
466                }
467            }
468        }
469
470        // Mark saga as compensated
471        {
472            let mut active_sagas = self.active_sagas.lock().unwrap();
473            if let Some(saga) = active_sagas.get_mut(&saga_id) {
474                saga.mark_compensated();
475            }
476        }
477
478        info!(saga_id = %saga_id, "Saga compensation completed");
479        Ok(())
480    }
481
482    /// Compensate a single step
483    async fn compensate_step(
484        &self,
485        action: &SagaAction,
486        context: &HashMap<String, String>,
487    ) -> Result<StepResult> {
488        if let Some(executor) = self.step_executors.get(&action.action_type) {
489            executor.compensate_step(action, context).await
490        } else {
491            Err(RustRabbitError::SagaExecutorNotFound(action.action_type.clone()).into())
492        }
493    }
494
495    /// Get saga status
496    pub fn get_saga_status(&self, saga_id: &SagaId) -> Option<SagaStatus> {
497        let active_sagas = self.active_sagas.lock().unwrap();
498        active_sagas.get(saga_id).map(|saga| saga.status.clone())
499    }
500
501    /// Get active saga count
502    pub fn active_saga_count(&self) -> usize {
503        self.active_sagas.lock().unwrap().len()
504    }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Executor that counts its calls and either always succeeds or always reports a failure.
    struct TestExecutor {
        execution_count: Arc<AtomicU32>,
        compensation_count: Arc<AtomicU32>,
        should_fail: bool,
    }

    impl TestExecutor {
        fn new(should_fail: bool) -> Self {
            Self {
                execution_count: Arc::new(AtomicU32::new(0)),
                compensation_count: Arc::new(AtomicU32::new(0)),
                should_fail,
            }
        }
    }

    #[async_trait::async_trait]
    impl SagaStepExecutor for TestExecutor {
        async fn execute_step(
            &self,
            _action: &SagaAction,
            _context: &HashMap<String, String>,
        ) -> Result<StepResult> {
            self.execution_count.fetch_add(1, Ordering::SeqCst);

            if self.should_fail {
                Ok(StepResult::Failure("Test failure".to_string()))
            } else {
                let mut result_context = HashMap::new();
                result_context.insert("executed".to_string(), "true".to_string());
                Ok(StepResult::Success(result_context))
            }
        }

        async fn compensate_step(
            &self,
            _action: &SagaAction,
            _context: &HashMap<String, String>,
        ) -> Result<StepResult> {
            self.compensation_count.fetch_add(1, Ordering::SeqCst);
            Ok(StepResult::Success(HashMap::new()))
        }
    }

    /// Executor whose forward action always asks to be retried.
    struct RetryExecutor {
        execution_count: AtomicU32,
    }

    #[async_trait::async_trait]
    impl SagaStepExecutor for RetryExecutor {
        async fn execute_step(
            &self,
            _action: &SagaAction,
            _context: &HashMap<String, String>,
        ) -> Result<StepResult> {
            self.execution_count.fetch_add(1, Ordering::SeqCst);
            Ok(StepResult::Retry)
        }

        async fn compensate_step(
            &self,
            _action: &SagaAction,
            _context: &HashMap<String, String>,
        ) -> Result<StepResult> {
            Ok(StepResult::Success(HashMap::new()))
        }
    }

    /// Build a pending step whose action and compensation both use `action_type`.
    fn pending_step(step_id: &str, action_type: &str, max_retries: u32) -> SagaStep {
        SagaStep {
            step_id: step_id.to_string(),
            action: SagaAction::new(action_type.to_string(), b"test".to_vec()),
            compensation: Some(SagaAction::new(
                action_type.to_string(),
                b"compensate".to_vec(),
            )),
            status: StepStatus::Pending,
            executed_at: None,
            compensated_at: None,
            retry_count: 0,
            max_retries,
        }
    }

    /// Read the stored status of one step of one saga.
    fn step_status(
        coordinator: &SagaCoordinator,
        saga_id: &SagaId,
        step_id: &str,
    ) -> Option<StepStatus> {
        let sagas = coordinator.active_sagas.lock().unwrap();
        sagas
            .get(saga_id)?
            .steps
            .iter()
            .find(|step| step.step_id == step_id)
            .map(|step| step.status.clone())
    }

    #[test]
    fn test_saga_id_generation() {
        let id1 = SagaId::new();
        let id2 = SagaId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_saga_instance_creation() {
        let saga = SagaInstance::new(
            "test_saga".to_string(),
            vec![pending_step("step1", "test", 3)],
        );
        assert_eq!(saga.saga_type, "test_saga");
        assert_eq!(saga.status, SagaStatus::Running);
        assert_eq!(saga.steps.len(), 1);
        assert_eq!(
            saga.get_current_step().map(|step| step.step_id.as_str()),
            Some("step1")
        );
        assert!(saga.completed_at.is_none());
    }

    #[tokio::test]
    async fn test_successful_saga_execution() {
        let mut coordinator = SagaCoordinator::new();
        let executor = Arc::new(TestExecutor::new(false));
        coordinator.register_executor("test".to_string(), executor.clone());

        let saga = SagaInstance::new(
            "test_saga".to_string(),
            vec![pending_step("step1", "test", 3)],
        );
        let saga_id = saga.saga_id.clone();

        // Starting executes exactly one step and leaves the saga running.
        coordinator.start_saga(saga).await.unwrap();
        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Running)
        );
        assert_eq!(
            step_status(&coordinator, &saga_id, "step1"),
            Some(StepStatus::Completed)
        );

        let executed = coordinator
            .active_sagas
            .lock()
            .unwrap()
            .get(&saga_id)
            .and_then(|saga| saga.context.get("executed").cloned());
        assert_eq!(executed.as_deref(), Some("true"));

        // Advancing a saga with no pending steps completes it.
        coordinator
            .execute_next_step(saga_id.clone())
            .await
            .unwrap();
        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Completed)
        );
        assert_eq!(executor.execution_count.load(Ordering::SeqCst), 1);
        assert_eq!(executor.compensation_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_failed_saga_compensation() {
        let mut coordinator = SagaCoordinator::new();

        // The first executor succeeds, the second fails.
        let executor1 = Arc::new(TestExecutor::new(false));
        let executor2 = Arc::new(TestExecutor::new(true));

        coordinator.register_executor("success".to_string(), executor1.clone());
        coordinator.register_executor("fail".to_string(), executor2.clone());

        let saga = SagaInstance::new(
            "test_saga".to_string(),
            vec![
                pending_step("step1", "success", 3),
                pending_step("step2", "fail", 3),
            ],
        );
        let saga_id = saga.saga_id.clone();

        // Step 1 succeeds.
        coordinator.start_saga(saga).await.unwrap();
        assert_eq!(
            step_status(&coordinator, &saga_id, "step1"),
            Some(StepStatus::Completed)
        );

        // Step 2 fails, which compensates step 1 within the same call.
        coordinator
            .execute_next_step(saga_id.clone())
            .await
            .unwrap();

        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Compensated)
        );
        assert_eq!(
            step_status(&coordinator, &saga_id, "step1"),
            Some(StepStatus::Compensated)
        );
        assert_eq!(
            step_status(&coordinator, &saga_id, "step2"),
            Some(StepStatus::Failed)
        );

        // Step 1 ran once and was compensated once.
        assert_eq!(executor1.execution_count.load(Ordering::SeqCst), 1);
        assert_eq!(executor1.compensation_count.load(Ordering::SeqCst), 1);
        // Step 2 ran once; a failed step is never compensated.
        assert_eq!(executor2.execution_count.load(Ordering::SeqCst), 1);
        assert_eq!(executor2.compensation_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_retry_exhaustion_triggers_compensation() {
        let mut coordinator = SagaCoordinator::new();
        let executor = Arc::new(RetryExecutor {
            execution_count: AtomicU32::new(0),
        });
        coordinator.register_executor("retry".to_string(), executor.clone());

        let saga = SagaInstance::new(
            "test_saga".to_string(),
            vec![pending_step("step1", "retry", 2)],
        );
        let saga_id = saga.saga_id.clone();

        // The first attempt asks for a retry, so the step goes back to pending.
        coordinator.start_saga(saga).await.unwrap();
        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Running)
        );
        assert_eq!(
            step_status(&coordinator, &saga_id, "step1"),
            Some(StepStatus::Pending)
        );

        // The second attempt reaches max_retries: the step fails and the saga is rolled back.
        coordinator
            .execute_next_step(saga_id.clone())
            .await
            .unwrap();
        assert_eq!(
            step_status(&coordinator, &saga_id, "step1"),
            Some(StepStatus::Failed)
        );
        assert_eq!(
            coordinator.get_saga_status(&saga_id),
            Some(SagaStatus::Compensated)
        );
        assert_eq!(executor.execution_count.load(Ordering::SeqCst), 2);
    }
}